
@m0o0scar
Created August 15, 2024 02:16

[github] THUDM/LongWriter

Source

Python / 4.0K lines of code. LongWriter: Unleashing 10,000+ Word Generation from Long Context LLMs

URL: https://github.com/THUDM/LongWriter

Conversation

▲ How does AgentWrite work?

AgentWrite is an automated pipeline for constructing ultra-long output data. It consists of two main steps:

  1. Planning: AgentWrite first takes a long-form writing instruction as input and breaks it down into multiple subtasks. Each subtask represents a paragraph in the final output and specifies the main point and word count requirement for that paragraph. This planning step is carried out by prompting a large language model (LLM) such as GPT-4.

    • Prompt: The prompt used for planning is found in agentwrite/prompts/plan.txt. It instructs the LLM to break down the writing instruction into subtasks, specifying the main point and word count for each paragraph.
    • Code: The planning process is implemented in agentwrite/plan.py. It reads the writing instructions from instructions.jsonl, generates plans using the GPT-4 API, and saves the results to plan.jsonl.
  2. Writing: AgentWrite then uses the generated plan to write each paragraph. It iteratively prompts the LLM with the writing instruction, the plan, and the text written so far, asking it to write the next paragraph for the current subtask.

    • Prompt: The prompt used for writing is found in agentwrite/prompts/write.txt. It provides the LLM with the writing instruction, the plan, and the text written so far, and instructs it to write the next paragraph for the current subtask.
    • Code: The writing process is implemented in agentwrite/write.py. It reads the plans from plan.jsonl, generates paragraphs using the GPT-4 API, and saves the results to write.jsonl.
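To make the planning output concrete, here is a minimal Python sketch of turning a generated plan into paragraph-level subtasks. It assumes, hypothetically, that the LLM returns one line per paragraph in the form `Paragraph N - Main Point: ... - Word Count: K words`; the real format is whatever agentwrite/prompts/plan.txt asks for, so the regular expression would need to match that. `Subtask`, `PLAN_LINE`, and `parse_plan` are illustrative names, not code from the repository.

```python
import re
from dataclasses import dataclass
from typing import List


@dataclass
class Subtask:
    index: int        # paragraph number in the final output
    main_point: str   # what the paragraph should cover
    word_count: int   # target length for the paragraph


# Hypothetical plan-line format assumed for this sketch:
#   "Paragraph 1 - Main Point: <text> - Word Count: 400 words"
PLAN_LINE = re.compile(
    r"Paragraph\s+(\d+)\s*-\s*Main Point:\s*(.+?)\s*-\s*Word Count:\s*(\d+)",
    re.IGNORECASE,
)


def parse_plan(plan_text: str) -> List[Subtask]:
    """Turn an LLM-generated plan into subtasks, skipping lines that do not match."""
    subtasks = []
    for line in plan_text.splitlines():
        match = PLAN_LINE.search(line)
        if match:
            subtasks.append(
                Subtask(int(match.group(1)), match.group(2), int(match.group(3)))
            )
    return subtasks


plan = """Paragraph 1 - Main Point: Introduce the estate and the Ashwood family. - Word Count: 400 words
Paragraph 2 - Main Point: Eliza first meets Thomas in the rose garden. - Word Count: 600 words"""

steps = parse_plan(plan)
total_words = sum(step.word_count for step in steps)
print(len(steps), total_words)  # 2 1000
```

Summing the per-paragraph word counts gives the overall length the plan is aiming for, which is how a plan-first design can target a specific output length.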

Overall, AgentWrite uses an LLM first to decompose a long-form writing instruction into a paragraph-by-paragraph plan and then to write the output one paragraph at a time according to that plan, which lets it produce long, well-structured text.

Key points:

  • LLM-based: AgentWrite relies heavily on LLMs for both planning and writing.
  • Iterative: The writing process is iterative, with the LLM generating one paragraph at a time based on the plan and the text written so far.
  • Structured output: Because every paragraph follows a planned main point and word count, the planning step helps keep the final output well-structured and coherent.
  • Scalable: Output length is set by the number of planned paragraphs and their word counts, so AgentWrite can target outputs of various lengths, making it suitable for a wide range of tasks.
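The iterative writing loop described above can be sketched as follows. This is an illustrative reimplementation, not the code in agentwrite/write.py: the prompt layout in `build_write_prompt` is a hypothetical stand-in for agentwrite/prompts/write.txt, and `llm` is any callable that maps a prompt string to generated text (for example, a thin wrapper around a chat-completion API). The stub `fake_llm` lets the loop run offline.

```python
from typing import Callable, List


def build_write_prompt(instruction: str, plan: List[str], written: str, step: str) -> str:
    """Hypothetical prompt layout standing in for agentwrite/prompts/write.txt."""
    plan_text = "\n".join(plan)
    return (
        f"Writing instruction:\n{instruction}\n\n"
        f"Plan:\n{plan_text}\n\n"
        f"Text written so far:\n{written}\n\n"
        f"Write the next paragraph for this step:\n{step}"
    )


def agent_write(instruction: str, plan: List[str], llm: Callable[[str], str]) -> str:
    """Generate one paragraph per plan step, feeding back everything written so far."""
    paragraphs: List[str] = []
    for step in plan:
        prompt = build_write_prompt(instruction, plan, "\n\n".join(paragraphs), step)
        paragraphs.append(llm(prompt).strip())
    return "\n\n".join(paragraphs)


# Offline stub: records each prompt and echoes the step it was asked to write.
seen_prompts: List[str] = []


def fake_llm(prompt: str) -> str:
    seen_prompts.append(prompt)
    step = prompt.rsplit("\n", 1)[-1]  # the step is the last line of the prompt
    return f"[paragraph for: {step}]"


plan = [
    "Paragraph 1 - Introduce the estate (400 words)",
    "Paragraph 2 - Eliza meets Thomas (600 words)",
]
text = agent_write("Write a tragic love story, 1000 words.", plan, fake_llm)
print(text)
```

In this sketch the full plan and the growing text are resent on every call, so each prompt is longer than the last; that is the cost of letting every paragraph see the context it must stay consistent with.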

Note: The code provided in the repository uses the GPT-4 API. You will need to configure your API key in the plan.py and write.py files to run AgentWrite.
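One way to avoid hard-coding the key in plan.py and write.py is to read it from an environment variable, as in the sketch below. The helper name `load_api_key` is hypothetical, and using it would mean editing the scripts to call it; `OPENAI_API_KEY` is the variable the official OpenAI Python SDK reads by default.

```python
import os


def load_api_key(env_var: str = "OPENAI_API_KEY") -> str:
    """Read the API key from the environment instead of hard-coding it in a script."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Set the {env_var} environment variable before running AgentWrite.")
    return key
```

With the variable exported in the shell (`export OPENAI_API_KEY=...`), the scripts could call `load_api_key()` in place of a literal key string, which also keeps the key out of version control.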

Data Source Content

Project: THUDM/LongWriter

URL: https://github.com/THUDM/LongWriter

Source tree:

THUDM/LongWriter/blob/main
├── LICENSE.txt
├── README.md
├── agentwrite
|  ├── plan.py
|  ├── prompts
|  |  ├── plan.txt
|  |  └── write.txt
|  └── write.py
├── evaluation
|  ├── eval_length.py
|  ├── eval_quality.py
|  ├── judge.txt
|  └── pred.py
├── requirements.txt
├── train
|  ├── README.md
|  ├── dataset.py
|  ├── main.py
|  ├── patch
|  |  ├── modeling_chatglm.py
|  |  └── modeling_llama.py
|  ├── pre_tokenize_glm4.py
|  ├── pre_tokenize_llama3.py
|  ├── sort_and_group.py
|  └── trainer.py
└── trans_web_demo.py

THUDM/LongWriter/blob/main/LICENSE.txt:

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright THU-KEG & Zhipu-AI

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

THUDM/LongWriter/blob/main/README.md:

# LongWriter: Unleashing 10,000+ Word Generation From Long Context LLMs

<p align="center">
    🤗 <a href="https://huggingface.co/datasets/THUDM/LongWriter-6k" target="_blank">HF Repo</a> • 📃 <a href="https://arxiv.org/abs/2408.07055" target="_blank">Paper</a>
</p>

https://github.com/user-attachments/assets/c7eedeca-98ed-43ec-8619-25137987bcde

Left: LongWriter-glm4-9b; Right: GLM-4-9B-chat

## 🔍 Table of Contents
- [⚙️ LongWriter Deployment](#deployment)
- [🤖️ AgentWrite](#agentwrite)
- [🖥️ Model Training](#longwriter-training)
- [📊 Evaluation](#evaluation)
- [👀 Cases](#case)
- [📝 Citation](#citation)

<a name="deployment"></a>
## ⚙️ LongWriter Deployment

**Environment Setup**:
Install the requirements with pip: `pip install -r requirements.txt`. For Llama-3.1 based models, we recommend using `transformers==4.43.0` or higher.
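To confirm that the installed `transformers` meets the 4.43.0 minimum recommended above, a small standard-library check such as the following can be used. This is a sketch, not part of the repository: `parse_release`, `meets_minimum`, and `check_transformers` are hypothetical helpers, and pre-release suffixes such as `.dev0` are simply ignored, so a development build is treated as the corresponding final release.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, version


def parse_release(v: str) -> tuple[int, ...]:
    """Turn '4.44.0.dev0' into (4, 44, 0), stopping at the first non-numeric part."""
    parts = []
    for piece in v.split("."):
        match = re.match(r"\d+", piece)
        if match is None:  # e.g. 'dev0'
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare release tuples after zero-padding them to the same length."""
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_transformers(minimum: str = "4.43.0") -> bool:
    """Return True if transformers is installed at or above `minimum`."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        return False
    return meets_minimum(installed, minimum)


if __name__ == "__main__":
    print("transformers OK" if check_transformers() else "please run: pip install -U 'transformers>=4.43.0'")
```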

We open-source two models: [LongWriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b) and [LongWriter-llama3.1-8b](https://huggingface.co/THUDM/LongWriter-llama3.1-8b), trained from [GLM-4-9B](https://huggingface.co/THUDM/glm-4-9b) and [Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B), respectively. These two models correspond to the "LongWriter-9B-DPO" and "LongWriter-8B" models in our paper. Try the model:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("THUDM/LongWriter-glm4-9b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
model = model.eval()
query = "Write a 10000-word China travel guide"
response, history = model.chat(tokenizer, query, history=[], max_new_tokens=32768, temperature=0.5)
print(response)
```
You may deploy your own LongWriter chatbot (like the one shown in the teaser video) by running:
```bash
CUDA_VISIBLE_DEVICES=0 python trans_web_demo.py
```

<a name="agentwrite"></a>
## 🤖️ AgentWrite

We are also open-sourcing AgentWrite under `agentwrite/`, our automated ultra-long output data construction pipeline. Run `plan.py` and then `write.py` to obtain the final data. Please configure your API key in the files.

<a name="longwriter-training"></a>
## 🖥️ Model Training

You can download and save the LongWriter-6k data through Hugging Face datasets ([🤗 HF Repo](https://huggingface.co/datasets/THUDM/LongWriter-6k)):
```python
from datasets import load_dataset

dataset = load_dataset('THUDM/LongWriter-6k')
# Every split is written to the same path, so this assumes the dataset has a single split
for split, split_dataset in dataset.items():
    split_dataset.to_json("train/LongWriter-6k.jsonl")
```
You can mix it with your own general SFT data. We adopt the code in LongAlign for model training (we use `transformers==4.43.0` for training on Llama-3.1), with slight modifications to adapt it to the new models. The training code is under `train/`. Please make sure to install FlashAttention 2 by following the instructions in the FlashAttention code base.

<a name="evaluation"></a>
## 📊 Evaluation

We introduce two evaluation benchmarks: LongBench-Write and LongWrite-Ruler. LongBench-Write focuses more on measuring the long output quality as well as the output length, while LongWrite-Ruler is designed as a lightweight stress test of the model's maximum output length. We provide our evaluation data and code under `evaluation/`. Run
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python pred.py
```
to get model responses. Then run `python eval_quality.py` and `python eval_length.py` to evaluate the quality ($S_q$) and length ($S_l$) scores. Remember to configure your OpenAI API key in `eval_quality.py`, since we adopt GPT-4o as the judge.

Here are the evaluation results on LongBench-Write:

[Figure: LongBench-Write evaluation results]

Here are the evaluation results on LongWrite-Ruler:

[Figure: LongWrite-Ruler evaluation results]

<a name="case"></a>
## 👀 Cases

Here are LongWriter-glm4-9b's outputs for randomly chosen test prompts.

User: Write a tragic love story about a lord's daughter falling in love with a servant, 5000 words.

Assistant: (6176 words)

The Aristocratic Estate

Nestled amidst the rolling hills of the English countryside, the grand estate of Lord Edward Ashwood was a beacon of elegance and sophistication. The sprawling manor, with its towering spires and expansive gardens, stood as a testament to centuries of wealth and power. The Ashwood family, known for their noble lineage and impeccable manners, had been pillars of the local community for generations. The grand halls were filled with the scent of polished wood and rich fabrics, echoing with the soft hum of servants attending to their masters' every need.

Lady Ashwood, the matriarch of the household, was a woman of grace and wisdom. Her beauty had faded with time, but her elegance remained untarnished. She was known for her charitable works and her ability to command respect and admiration from all who met her. Her daughters, Eliza and Arabella, were the epitome of aristocratic charm. Eliza, the elder, was a poised and intelligent young woman, often seen with a book in hand or engaged in scholarly discussions. Arabella, the younger, was vivacious and lively, her laughter ringing through the halls like a bell. Both girls were expected to marry well, their futures already mapped out by their parents.

The household was a bustling hive of activity. The grand dining hall was a testament to the family's wealth, adorned with intricate tapestries and fine china. The servants moved with practiced precision, their duties meticulously assigned. The butler, Mr. Blackwood, was a stern but fair man, ensuring that every aspect of the household ran smoothly. The head cook, Mrs. Brown, was a master of her craft, her culinary skills renowned throughout the county. The young page boys and maids scurried about, their faces a mix of innocence and the early signs of budding adulthood.

The Ashwood estate was a world apart from the simpler lives of the villagers who worked the surrounding lands. The gulf between the two worlds was vast and unbridgeable, a chasm that no one dared to cross. The servants, though integral to the estate's operation, were kept at a distance, their presence a mere background to the grandeur that defined the Ashwoods.

In this world of opulence and order, the seeds of a tragic love story were about to be sown.

The Hidden Heart of Eliza

Eliza Ashwood was the embodiment of aristocratic grace, her every movement a testament to years of refinement and education. Yet, beneath her composed exterior lay a heart yearning for something more, something the grand estate could not provide. Eliza was not merely a bookish daughter, though her love for literature was well-known. She possessed a depth of thought and a sensitivity that set her apart from her peers. Her days were filled with the routine of the estate—participating in social gatherings, learning the intricacies of diplomacy and etiquette, and engaging in scholarly pursuits. But her nights were often spent in quiet reflection, her mind wandering to thoughts that were far removed from the expectations of her station.

Eliza's interest in the estate's servants was a well-guarded secret. She often found herself drawn to the quieter corners of the manor, where the servants conducted their duties with a sense of quiet dignity. It was during one such moment, while taking a solitary walk through the estate's expansive gardens, that she first encountered Thomas. He was a young servant, tasked with tending to the roses that lined the estate's perimeter. His back was to her as he worked, his hands deftly pruning the delicate blooms. Eliza was captivated by the ease and precision with which he moved, a stark contrast to the hurried chaos of the household's daily operations.

As she watched him, a sense of tranquility washed over her. For the first time in what felt like an eternity, Eliza felt a sense of peace. Thomas's presence was unassuming, his focus entirely on his work. It was a stark contrast to the constant hustle and bustle of the manor. Eliza found herself lingering, her curiosity piqued by the man who seemed to exist in a world of his own, untouched by the grandeur that surrounded him.

Their first conversation was innocent and brief, a simple exchange over the health of Thomas's mother. Eliza's genuine concern and Thomas's humble gratitude left an impression on both of them. From that moment, their interactions became more frequent, though always under the watchful eyes of the household. Eliza would often find herself seeking out Thomas, her heart fluttering with a mix of excitement and fear. Their conversations were simple and heartfelt, often centered around the beauty of the garden or the weather. Yet, it was enough to kindle a flame in Eliza's heart, a flame that she knew she must keep hidden.

Eliza's feelings for Thomas were a source of both comfort and conflict. On one hand, he represented a connection to a simpler, more genuine world, a world where she felt truly seen and understood. On the other hand, she was acutely aware of the social and economic chasm that separated them. The idea of a relationship with a servant was not only unimaginable but also scandalous. Eliza's heart was a battlefield of hope and despair, torn between her desire for love and her duty to her family and society.

As their bond deepened, Eliza found herself questioning her place in the world. She had been raised to aspire to greatness, to marry a man of equal or greater status, to uphold the Ashwood legacy. Yet, here she was, falling in love with a servant, a man who represented everything her world sought to exclude. It was a love that defied logic and reason, a love that threatened to upend the carefully constructed life she had been destined to lead.

Eliza's secret affection for Thomas was a delicate balance, a fragile thread that she clung to with all her might. She knew that revealing her feelings would not only shatter her own dreams but also bring disgrace upon her family. Yet, the more time she spent with Thomas, the more she realized that her heart could no longer be contained. It was a love that was as beautiful as it was dangerous, a love that would test the very foundations of her world.

Thomas's Hidden Life

Thomas had been born into a family of humble means, his parents working the fields to make ends meet. From a young age, he had known the hardships of rural life, the backbreaking labor and the sting of poverty. Yet, it was not the struggle that defined him but the resilience and determination he found within himself. When the Ashwood estate advertised for new servants, Thomas saw it as an opportunity to escape the confines of his village and forge a better future for himself.

Life at the Ashwood estate was a stark contrast to his previous existence. The grandeur and opulence were overwhelming at first, but Thomas adapted quickly, his innate sense of duty driving him to excel in his duties. He was assigned to the garden, a role that suited his quiet nature and love for the natural world. Each day, he moved through the estate with a sense of purpose, his hands deftly tending to the plants, his mind a sanctuary of calm amidst the chaos.

Despite the differences in their worlds, Thomas felt a strange sense of belonging at the Ashwood estate. The grand halls and ornate decorations were a world apart from his simple upbringing, but there was a certain elegance to the place that he found comforting. The strict routines and the disciplined lives of the servants provided a structure that he had always craved. He found solace in the predictability of his tasks, in the rhythm of his workday that allowed him moments of quiet reflection.

It was during these moments of solitude that Thomas began to notice Eliza. Her presence was like a gentle breeze through the garden, her elegance contrasting with the rustic beauty of the flowers she admired. Thomas was initially drawn to her beauty, but as their interactions grew, he came to appreciate her intelligence and kindness. Eliza was unlike anyone he had ever met, her conversations filled with wisdom and compassion. She saw beyond his station, into the person he truly was.

Their bond grew slowly, built on shared moments of understanding and mutual respect. Thomas found himself looking forward to their brief conversations, his heart lightening with each encounter. Eliza's genuine interest in his life and her willingness to see past the societal barriers that separated them filled him with a sense of hope he had long thought lost. For the first time, Thomas felt valued and seen, not as a servant, but as a person with dreams and aspirations of his own.

Yet, Thomas was acutely aware of the dangers that their relationship posed. He knew the social and economic chasm that separated them, the impossibility of a future together. The idea of love was a fragile dream in a world that demanded conformity and respectability. Thomas's heart was a battlefield of hope and fear, torn between his desire for Eliza and the reality of their circumstances. He loved her with all his heart, but he was also pragmatic, understanding that their love was a risk he could not afford to take.

As their bond deepened, Thomas found himself questioning his place in the world. He had come to the Ashwood estate seeking a better life, but now he found himself caught in a web of emotions that threatened to unravel everything he had worked for. The love he felt for Eliza was a beautiful but dangerous distraction, a reminder of the dreams he dared not speak aloud. He knew that their love was a fragile thread, one that could easily be severed by the harsh realities of their world.

Thomas's heart was a sanctuary of love and fear, a place where his dreams and reality clashed. He loved Eliza with a passion that defied reason, but he was also realistic, understanding that their love was a fragile hope in a world that demanded conformity. As their bond grew stronger, Thomas found himself standing at a crossroads, his future uncertain and his heart in turmoil. He knew that his love for Eliza was a risk he was willing to take, but he was also aware of the dangers that lay ahead.

The Unspoken Bond

The bond between Eliza and Thomas grew stronger with each passing day, a silent yet powerful connection that neither could deny. Their conversations, though brief, were filled with a depth of understanding that transcended the barriers of their social standing. Eliza found herself looking forward to their encounters, each interaction a source of solace and joy in her otherwise structured and rigid world. Thomas, in turn, felt a sense of belonging and purpose that he had never known before, Eliza's presence a beacon of hope in his otherwise monotonous life.

One evening, as the sun dipped below the horizon, casting a golden glow over the garden, Eliza found herself once again drawn to the roses Thomas tended. This time, she approached him with a quiet determination. "Thomas," she began, her voice barely above a whisper, "I wanted to thank you for always being there, for listening to me when no one else would."

Thomas looked up, his eyes meeting hers with a mixture of surprise and gratitude. "It's my pleasure, Miss Eliza. You've always been kind to me."

Eliza took a deep breath, her heart pounding in her chest. "Thomas, there's something I need to tell you. I... I care for you deeply. More than just as a friend or a servant. I... I love you."

Thomas's eyes widened, his heart racing at the weight of her words. He had felt the same but had never dared to voice them, afraid of the consequences. "Eliza," he whispered, his voice trembling, "I... I feel the same. But we both know the dangers of what we feel."

Eliza's eyes filled with a mixture of hope and despair. "I know the risks, Thomas. I've thought about it night and day. But I can't deny my feelings any longer. I need you to know how I truly feel."

Thomas's heart ached with the weight of her words. He loved her with every part of his being, but the reality of their situation loomed large. "Eliza, we come from two different worlds. The gap between us is vast and unbridgeable. If we pursue this, we risk everything—our futures, our families, our very lives."

Eliza's eyes glistened with unshed tears. "I understand that, Thomas. But I can't live without you. I need you in my life, even if it means defying everything I've ever been taught."

Thomas took a step closer, his hand reaching out to gently touch her cheek. "Eliza, you mean more to me than words can express. But we must be careful. The world is not kind to those who defy its rules."

Eliza nodded, her heart heavy with the weight of their reality. "I know, Thomas. But I'm willing to face whatever comes. I love you, and I can't let that go."

Their fingers brushed against each other's, a silent promise of the love they shared. It was a love that defied reason and societal norms, a love that both terrified and inspired them. They knew the risks, but they were also aware of the beauty and depth of the bond they had forged.

As they stood there under the golden light of the setting sun, their hearts beat in unison, a testament to the love that had grown between them. They were two souls entwined by fate, their love a fragile yet resilient thread that defied the world's expectations. Together, they faced the uncertainties of their future, their hearts united by a love that was as powerful as it was dangerous.

The Struggle Within

Eliza's heart was a tempest of emotions, torn between her love for Thomas and the societal expectations that loomed over her. She knew that her feelings for Thomas were real, profound, and unshakable. Yet, the weight of her upbringing and the expectations of her family were a constant reminder of the peril she faced. Every day was a battle within herself, a struggle to reconcile her heart with her duty.

Eliza's parents had always been strict in their expectations, instilling in her a sense of responsibility and obligation to the Ashwood legacy. They had planned her future, envisioning a marriage that would secure the family's status and wealth. The idea of defying those plans, of pursuing a love that defied societal norms, was a terrifying prospect. Eliza feared the disgrace that would come to her family, the ruin of her carefully constructed life.

Yet, every moment spent with Thomas only deepened her resolve. His kindness, his intelligence, and his unwavering support filled a void in her heart that nothing else could. Eliza found herself longing for the simple, genuine moments they shared, the comfort of his presence and the joy that his love brought into her life. She realized that her happiness, her true happiness, lay in the love she felt for Thomas, not in the expectations of her family.

Eliza's internal conflict was a relentless torment. She loved her family and respected their wishes, but she also loved Thomas with a passion that she could no longer ignore. She spent sleepless nights wrestling with her emotions, her mind a whirlwind of doubt and determination. She knew that she could not continue living a lie, that she had to make a choice between her heart and her duty.

One evening, as she stood before her mirror, staring at her reflection, Eliza made a decision. She would speak to her parents, reveal her feelings for Thomas, and face the consequences. She knew it would be a difficult conversation, but she also knew that she could no longer live in silence. She owed it to herself and to Thomas to be honest about her heart.

As she prepared to face her parents, Eliza's heart was filled with a mixture of fear and hope. She was ready to fight for her love, to defy the expectations that had been laid out for her. She was willing to face whatever consequences might come, as long as she could be true to herself and to the man who had captured her heart.

The Heart-Wrenching Confession

Eliza took a deep breath, steeling herself for the conversation that was about to unfold. She found her parents in the drawing room, their faces a mask of calm as they sipped their tea. Eliza's heart pounded in her chest as she approached them, her hands trembling slightly. "Father, Mother," she began, her voice steady but filled with emotion, "I need to talk to you about something important."

Lady Ashwood set down her teacup, her eyes narrowing slightly. "What is it, Eliza? Out with it."

Eliza took a deep breath, her eyes meeting her parents' with a mixture of determination and fear. "I... I have something to confess. I... I am in love."

Both of her parents' faces paled, their expressions shifting from surprise to concern. Lord Ashwood's stern gaze softened slightly, while Lady Ashwood's eyes filled with a mixture of shock and worry. "Eliza, what do you mean? With whom?" Her voice was a hushed whisper, as if speaking too loudly might shatter the fragile reality they all lived in.

Eliza's heart ached as she forced herself to speak the words that had been burning in her chest for so long. "I am in love with Thomas, the servant who tends the garden. I... I can't deny it any longer. I love him with all my heart."

The room fell into a heavy silence, the air thick with the weight of unspoken words. Lord Ashwood's face darkened, his eyes narrowing as he absorbed her words. "Thomas? The servant? Eliza, this is unthinkable. How could you even entertain such a notion?"

Lady Ashwood's eyes filled with tears, her voice trembling. "Eliza, you must be mistaken. This cannot be real. You must forget him immediately."

Eliza's heart broke as she looked at her parents, seeing the pain and disappointment in their eyes. She knew that their reaction was inevitable, but the words still felt like a dagger to her soul. "I am not mistaken, Mother. I am in love with Thomas, and I cannot change that. I... I cannot live a lie any longer."

Lord Ashwood's voice was cold and stern. "Eliza, you must understand the consequences of your actions. This is not just about you. It is about the Ashwood legacy, about our reputation and honor. You cannot throw all of that away for a mere servant."

Eliza's eyes filled with tears, her heart aching with the weight of their words. "I know, Father. I know the risks. But I cannot live without Thomas. He has given me a love that I never thought possible, a love that makes me feel alive. I cannot deny my feelings any longer."

Lady Ashwood's voice broke, her tears flowing freely. "Eliza, you don't understand. This is more than just a love affair. It is a betrayal of everything we stand for. You are putting our entire family at risk."

Eliza's heart was in turmoil, torn between her love for Thomas and her duty to her family. She knew that her parents were right, that the consequences of her actions could be devastating. But she also knew that she could not live a life of lies, that she had to be true to herself and to the man who had captured her heart.

As the conversation continued, Eliza's resolve only strengthened. She knew that she would face consequences, but she was willing to bear them. She would fight for her love, even if it meant losing everything.

The Unraveling of Dreams

The days following Eliza's confession were a whirlwind of turmoil and heartache. Lord and Lady Ashwood were determined to put an end to what they saw as a scandalous relationship, and their actions were swift and unforgiving. Eliza was forbidden from seeing Thomas, her movements closely monitored by the household staff. The garden, once a sanctuary for their secret meetings, became a place of dread, its beauty marred by the weight of their separation.

Eliza's heart was in shambles, each day a painful reminder of the love she had been forced to abandon. She spent her nights in tears, her mind a constant echo of the words spoken by her parents. The walls of the grand estate seemed to close in around her, the once familiar surroundings now a prison of her own making. Her once vibrant spirit dimmed, her laughter replaced by a hollow echo of its former self.

Thomas, too, was not immune to the devastation. He could sense the change in Eliza, the sadness that clouded her eyes and weighed heavily on her heart. He longed to reach out to her, to offer her the comfort she so desperately needed, but he knew that any attempt would only worsen her situation. His heart ached with every passing day, his love for Eliza a silent but relentless torment.

The household was abuzz with the news of Eliza's forbidden love, the servants whispering behind closed doors and the guests at social gatherings casting judgmental glances her way. Eliza's reputation was under siege, her once impeccable standing now tarnished by the scandal. The pressure was immense, and she found herself questioning her every action, her every decision.

Despite the distance imposed upon them, Eliza and Thomas found ways to communicate. Secret notes were passed through the estate, their words a lifeline in the storm of their separation. Each letter was a balm to their wounded hearts, a reminder of the love that had once brought them together. But even these brief exchanges were fraught with danger, each note a potential discovery that could lead to even greater consequences.

Eliza's parents, relentless in their efforts to break her spirit, increased the pressure. They sought to distract her with social engagements and scholarly pursuits, hoping to divert her attention from Thomas. But Eliza's heart remained steadfast, her love for Thomas an unyielding force that defied their attempts to sever their bond.

The strain of their situation began to take its toll on Eliza's health. She grew pale and weak, her once vibrant energy sapped by the emotional and physical weight of her plight. The Ashwood estate, once a symbol of her family's power and prestige, now felt like a gilded cage, confining her spirit and threatening to crush her soul.

Thomas, watching from a distance, felt the same despair. He saw the pain in Eliza's eyes, the way her once bright smile had faded into a shadow of its former self. His heart ached with every passing day, his love for her a beacon of hope in the darkness that had enveloped her life.

As the days turned into weeks, the love between Eliza and Thomas remained unbroken, a testament to their unwavering devotion to each other. They faced their trials with courage and resilience, their bond stronger for the hardships they endured. But the weight of their circumstances was a constant reminder of the dangers that loomed over them, a reminder that their love, while powerful, was also fragile and vulnerable to the world's judgment.

The Breaking Point

The weight of their secret love became too much for Eliza to bear. The constant pressure from her parents, the judgment of society, and the physical and emotional toll of their separation began to erode her spirit. One evening, as the moon cast a silvery glow over the Ashwood estate, Eliza made a desperate decision. She would run away, escape the confines of her life and the expectations that had been laid upon her.

Eliza slipped out of the manor under the cover of darkness, her heart pounding with a mixture of fear and determination. She made her way to the garden, where Thomas had promised to meet her. The garden, once a place of solace, now felt like a battlefield, each step she took a step closer to the unknown.

Thomas was waiting for her, his face a mask of concern and hope. "Eliza, I knew you would come," he whispered, his voice trembling with emotion. "I've been so worried about you."

Eliza's eyes filled with tears as she threw her arms around him, her body shaking with the weight of her emotions. "Thomas, I can't stay any longer. The pressure is too much. I need to be with you, even if it means losing everything."

Thomas held her close, his heart aching with the depth of her words. "Eliza, I understand. I've felt the same. But we must be careful. Running away is not a solution. It will only make things worse."

Eliza pulled back, her eyes searching Thomas's face for reassurance. "I know, Thomas. But I can't go back. I can't live a lie any longer. I need you, and I need to be free to love you."

Thomas's heart broke as he looked into her eyes, seeing the pain and determination that burned within them. He knew that Eliza was right, that running away was not the answer, but he also knew that he could not bear to see her suffer any longer. "Eliza, I love you more than words can express. But we must be smart about this. We need a plan, a way to make a life together without the world's judgment."

Eliza nodded, her tears drying on her cheeks as she took a deep breath. "I trust you, Thomas. I know you will guide me. I just need to be with you, to feel your love and support."

Thomas led Eliza to a secluded part of the garden, away from the prying eyes of the household staff. He took her hands in his, his gaze steady and resolute. "Eliza, I promise you that I will do everything in my power to protect you and to build a life together. But we must be patient, and we must be careful. We cannot rush into anything without a solid plan."

Eliza's heart swelled with hope and love as she looked into Thomas's eyes. "I believe in you, Thomas. I believe that we can find a way, that our love can overcome everything. I just need you."

Thomas's eyes softened as he pulled Eliza into his arms, holding her close as if to shield her from the world's dangers. "I love you, Eliza. More than anything. And I will do whatever it takes to be with you, to build a life that is truly ours. But we must be strong, and we must be smart."

Eliza's heart felt lighter as she rested her head on Thomas's chest, listening to the steady beat of his heart. She knew that their journey would be fraught with challenges, but she also knew that their love was strong enough to overcome them. "I love you too, Thomas. And I will stand by you no matter what."

As they stood there in the moonlit garden, their hearts beating in unison, Eliza and Thomas made a silent vow to each other. They would face the world together, their love a beacon of hope in the darkness. They would fight for their happiness, no matter the cost.

The Ultimate Sacrifice

The days that followed were a blur of desperation and hope. Eliza and Thomas knew that their time was limited, that the authorities would not be long in finding them. They spent their days in hiding, moving from one safe house to another, always looking over their shoulders for the threat of discovery. Each night, they found solace in each other's arms, their love a fragile yet resilient thread that held them together in the face of adversity.

As the authorities closed in, Eliza and Thomas knew that their time was running out. They had to make a decision, a choice that would determine the course of their lives. Eliza looked into Thomas's eyes, her heart heavy with the weight of their reality. "Thomas," she whispered, her voice trembling, "we can't keep running. They will find us eventually. We need a plan, a way to ensure that our love endures, even if we cannot be together."

Thomas's eyes filled with a mixture of love and sorrow. "Eliza, I have been thinking about this. There is only one way to ensure that our love endures, that our sacrifice is not in vain. We must marry, legally. It is our only chance to be together, to build a life that is truly ours."

Eliza's eyes widened in shock and fear. "Thomas, I can't marry you. It is illegal, and it would mean losing everything—our freedom, our family, our future. I can't ask you to do that."

Thomas's grip on her hand tightened, his voice firm and resolute. "Eliza, you must understand. Marrying me is the only way we can be together, the only way to ensure that our love endures. I am willing to face the consequences, to give up everything for you. But I need you to be with me, to stand by my side."

Eliza's heart ached with the weight of Thomas's words. She knew that he was right, that their love was worth any sacrifice. But the idea of losing everything, of defying society's expectations, was a daunting prospect. "Thomas, I love you more than anything. But I am afraid of what will happen if we marry. I am afraid of losing you, of losing everything we have built together."

Thomas's eyes softened as he pulled Eliza into his arms, holding her close as if to shield her from the world's dangers. "Eliza, you must trust me. I will protect you, I will fight for us. We can make this work, we can build a life together. But we must be brave, and we must be united."

Eliza's heart swelled with love and determination as she looked into Thomas's eyes. She saw the depth of his devotion, the unwavering commitment he had for her. She knew that he was right, that their love was worth any sacrifice. "Thomas, I trust you. I will marry you, and I will stand by your side no matter what."

As they stood there in the dim light of their hidden sanctuary, Eliza and Thomas made their final vow to each other. They would marry, defy society's expectations, and fight for their love. They would face the world together, their hearts united by a love that was as powerful as it was dangerous.

Their wedding was a secret ceremony, held in the early hours of dawn to avoid detection. They exchanged vows in a small, secluded chapel, their hearts beating in unison as they promised to love and support each other through every trial and tribulation. The ceremony was simple but profound, a testament to their love and their commitment to each other.

As they emerged from the chapel, hand in hand, Eliza and Thomas knew that their journey had only just begun. They faced the world with a newfound sense of purpose and determination, their love a beacon of hope in a world that sought to crush them. They were ready to face whatever challenges lay ahead, united by a love that was as powerful as it was fragile.

The Heart-Wrenching End

The authorities discovered Eliza and Thomas's marriage soon after the ceremony. The news spread like wildfire through the estate and the surrounding villages, igniting a firestorm of outrage and scandal. Lord and Lady Ashwood were devastated, their pride and reputation shattered by their daughter's defiance. Eliza was immediately confined to her room, her parents refusing to speak to her or acknowledge her existence.

Thomas, determined to protect Eliza, stood by her side, his love unwavering despite the danger it posed to him. He was dismissed from his position at the estate, his future prospects ruined. The villagers turned their backs on him, their judgmental stares a constant reminder of the price he had paid for his love.

Eliza's heart was in tatters, her spirit broken by the weight of her parents' rejection and the world's condemnation. She spent her days in solitude, her once vibrant spirit dimmed by the sorrow that enveloped her. Thomas, though strong and resilient, could see the despair in her eyes, the pain that gnawed at her soul.

One evening, as the moon cast a silvery glow over the estate, Thomas took Eliza into the garden, the same place where their love had blossomed. He held her close, his voice trembling with emotion. "Eliza, I am so sorry. I never wanted this to happen. I love you more than anything, and I would do anything to make things right."

Eliza's eyes filled with tears, her voice a whisper. "Thomas, I know. But I can't bear the pain any longer. I can't live in this world of judgment and rejection. I need to be free, to find peace."

Thomas's heart broke as he looked into Eliza's eyes, seeing the pain and determination that burned within them. He knew that she was right, that she needed to be free from the world's constraints. "Eliza, please don't do this. I can't live without you. I love you, and I will always love you."

Eliza pulled back, her eyes searching Thomas's face for one last moment of solace. "Thomas, I love you too. But I need to be free, to find the peace that has eluded me. I am so sorry."

Thomas's heart ached with the weight of Eliza's words. He knew that he could not stop her, that her determination was unwavering. "Eliza, I will always love you, no matter what. But I need you to be happy, to find the peace that you deserve."

Eliza's eyes glistened with unshed tears as she looked into Thomas's eyes, her heart aching with the weight of her decision. "Thomas, I will always love you. And I will never forget the love we shared. I am so sorry for the pain I have caused you."

Thomas's eyes filled with tears as he pulled Eliza into his arms, holding her close as if to shield her from the world's dangers. "Eliza, you have given me the greatest gift of my life—the love of a kind and beautiful soul. I will always cherish that love, no matter what."

As they stood there in the moonlit garden, their hearts beating in unison, Eliza and Thomas made their final farewell. They knew that their love had been powerful and beautiful, a testament to the strength of the human spirit. But they also knew that their love had come at a great cost, a sacrifice that had forever changed their lives.

Eliza took a deep breath, her heart heavy with the weight of her decision. "Thomas, I need to go. I need to find the peace that has eluded me. I will always love you, and I will never forget you."

Thomas's eyes filled with tears as he released Eliza, his heart aching with the weight of her departure. "Eliza, be happy, be free. I will always love you, and I will always remember the love we shared."

Eliza walked away from Thomas, her heart breaking with each step. She knew that she was leaving a piece of her heart behind, that her love for Thomas would forever be a part of her. But she also knew that she needed to find the peace that had eluded her, to build a life that was truly her own.

Thomas watched Eliza disappear into the darkness, his heart aching with the weight of her departure. He knew that their love had been powerful and beautiful, a testament to the strength of the human spirit. But he also knew that their love had come at a great cost, a sacrifice that had forever changed their lives.

As the moon cast its silvery glow over the garden, Eliza and Thomas's love story came to an end. Their love had been powerful and beautiful, a testament to the strength of the human spirit. But it had also been tragic and heartbreaking, a reminder of the dangers that love can face in a world that often seeks to crush it.


User: Write a 10,000-word travel guide to China

A: (10,691 Chinese characters)

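Overview of Travel in China

China, this ancient and mysterious land of the East, has a long history and a rich culture. As one of the most populous countries in the world, China is renowned for its diverse geography, unique cultural heritage, and warm, hospitable people. Its bustling cities, magnificent natural scenery, and deep cultural roots draw countless visitors.

The appeal of travel in China shows first in its varied geography. It ranges from the snowy plateaus of the north to the blue seas and skies of the southern coast, and from the boundless deserts of the west to the green plains of the east, so China encompasses almost every type of natural landscape. Visitors can experience climates and terrains ranging from cold to hot and from dry to humid, enough to satisfy any image of natural scenery they might have.

Second, China's cultural heritage is a highlight of travel there. With more than five thousand years of civilization, China has left behind countless treasured cultural sites. From grand ancient structures such as the Great Wall, the Forbidden City, and the Terracotta Warriors to ancient temples and gardens such as the Shaolin Temple and the Classical Gardens of Suzhou, every site carries deep historical and cultural significance. These heritage sites are not only witnesses to history but also the crystallization of the Chinese nation's wisdom, and they attract countless visitors who come to see and learn.

In addition, China's people are an important part of its appeal. Chinese people are warm, hospitable, and willing to help, letting visitors feel at home during their travels. Whether in the lively streets of big cities or the quiet lanes of small villages, people are always willing to share their stories and experiences, helping visitors better understand and experience Chinese culture.

In short, with its unique natural scenery, rich cultural heritage, and hospitable people, China has become a destination that travelers around the world aspire to visit. History enthusiasts, nature explorers, and culture seekers alike can all find their own travel paradise in China.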

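Administrative Divisions of China

China is a vast, multi-ethnic country with complex and diverse administrative divisions. There are 34 provincial-level administrative divisions in total, comprising 23 provinces, 5 autonomous regions, 4 municipalities directly under the central government, and 2 special administrative regions. Each has its own geographic, cultural, and historical background, giving visitors a wide variety of travel options.

First, the 23 provinces form the main body of China's administrative divisions. These include Shandong, Henan, Sichuan, Hunan, Hubei, Jiangxi, Guangdong, Guangxi, Guizhou, Yunnan, Shaanxi, Gansu, Qinghai, Ningxia, Xinjiang, and others. Each province has its own natural scenery and cultural character. For example, Shandong is known for its long history and rich cultural heritage, with famous sites such as Mount Tai and the Temple of Confucius in Qufu. Henan is an important birthplace of Chinese civilization, with historical and cultural attractions such as the Longmen Grottoes and the Qingming Riverside Landscape Garden. Sichuan is famous for its beautiful scenery and distinctive cuisine, and natural landscapes such as Jiuzhaigou and Mount Emei draw large numbers of visitors.

Second, the 5 autonomous regions are the Inner Mongolia Autonomous Region, the Guangxi Zhuang Autonomous Region, the Tibet Autonomous Region, the Ningxia Hui Autonomous Region, and the Xinjiang Uygur Autonomous Region. These regions are known for their distinctive ethnic cultures and natural landscapes. For example, Inner Mongolia has vast grasslands and spectacular deserts, making it an ideal place for horseback riding and grassland adventures. Guangxi is famous for its beautiful mountains and rivers and its rich ethnic culture, and its natural scenery, such as the landscapes of Guilin and the Li River, enchants visitors. Tibet, the Roof of the World, has magnificent natural and cultural landmarks such as Mount Everest (Qomolangma) and the Potala Palace, making it an excellent choice for pilgrimage and adventure.

The 4 municipalities are Beijing, Shanghai, Tianjin, and Chongqing. These cities are all major economic, cultural, and transportation hubs of China, with rich historical sites and modern architecture. For example, Beijing is China's capital, home to world-famous heritage sites such as the Forbidden City and the Great Wall. Shanghai is China's economic center, with modern landmarks and commercial districts such as the Oriental Pearl Tower and the Bund. Tianjin is known for its distinctive port culture and historical sites, and Chongqing attracts many visitors with its spectacular mountains and rivers and its distinctive hot pot culture.

Finally, the 2 special administrative regions are Hong Kong and Macau. Both are known for their distinctly international and modern character. Hong Kong is an international financial center and shopping paradise, with famous attractions such as Victoria Harbour and Disneyland. Macau is known for its gaming industry and its history as a Portuguese colony, and attractions such as the Ruins of St. Paul's and The Venetian resort draw large numbers of visitors.

In short, China's administrative divisions are complex and diverse. Each provincial-level division has its own geographic, cultural, and historical background, giving visitors a rich variety of choices. Whether visitors seek natural scenery, history and culture, or modern cities, China can meet their needs.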

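China's Climate

China's climate is diverse, with marked differences from north to south and from east to west. Broadly speaking, China can be divided into four main climate types: cold, temperate, subtropical, and tropical. The north has cold winters and hot summers, while the south is mild and pleasant, like spring all year round. This diversity not only provides rich conditions for China's agriculture but also offers visitors a wide range of travel experiences.

In the north, winters are long and cold, most typically in the Northeast. The Harbin Ice and Snow Festival attracts large numbers of visitors who come to see the ice sculptures and try winter sports. Summers are relatively short, though temperatures can be high, and the season suits travelers escaping the heat. Areas of the Northeast such as the Greater Khingan Mountains and Changbai Mountain are lush and green in summer and make good summer retreats.

The central region, which includes North China, East China, and Central China, has mainly a temperate monsoon climate. Winters are cold and dry, and summers are hot and rainy. Major cities such as Beijing, Shanghai, and Wuhan all lie in this region. In winter, visitors to Beijing can tour famous sites such as the Forbidden City and the Great Wall and experience the charm of history in the cold. In summer, visitors to Shanghai can stroll along the Bund and visit the City God Temple, enjoying the bustle and coolness of the city.

The south has mainly subtropical and tropical monsoon climates. The four seasons are distinct here, with warm winters and hot, humid summers. Provinces such as Guangdong, Guangxi, Fujian, and Hainan all belong to this region. In winter, Hainan is as warm as spring, an ideal place to escape the cold. In summer, visitors to Xiamen can tour Gulangyu Island and Nanputuo Temple, enjoying the cool breezes and beautiful scenery of the coast.

The west includes Xinjiang, Tibet, Qinghai, Gansu, and other provinces, with a variety of climate types ranging from cold to highland climates. Xinjiang's winters are cold and dry, and its summers are hot and windy, making it a good destination for adventure and grassland travel. Tibet has mainly a highland climate, with cold winters and cool summers, and is a popular destination for pilgrimage and mountaineering. Natural landscapes such as Qinghai Lake and Chaka Salt Lake are especially enchanting in summer.

Overall, China's climate gives visitors a wide range of travel options. Different seasons suit different activities, so visitors can choose the right time to travel according to their preferences and needs. For example, winter is a good time to head to the north and northwest to enjoy ice and snow scenery, while summer is a good time to escape the heat in the south and the southeastern coastal areas. Understanding China's climate helps visitors plan their routes and enjoy a pleasant journey.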

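Major Tourist Attractions in China

China has many world-famous tourist attractions that represent not only its rich history and culture but also its unique natural scenery. Here are some of China's most representative attractions:

The Great Wall

The Great Wall, one of China's most famous symbols, is the longest military defense structure in the world. It winds through the mountains and hills of northern China, with a total length of about 21,196 km. Visitors can tour several well-known sections, such as the Badaling, Mutianyu, and Jiayuguan sections. Badaling is the best-preserved section and also the most visited. Mutianyu is known for its steep terrain and beautiful scenery, and Jiayuguan showcases the northwestern style of the wall. In every season, the Great Wall is an important place for visitors to take photos and experience Chinese history.

The Forbidden City

Gugong, also known as the Forbidden City, served as the imperial palace of both the Ming and Qing dynasties and sits in the center of Beijing. This grand palace complex covers about 720,000 square meters and has more than 9,000 rooms. The Forbidden City is famous not only for its spectacular architecture and exquisite decoration but also for its large collection of precious cultural relics and works of art. Visitors can tour the main buildings, such as the Hall of Supreme Harmony, the Palace of Heavenly Purity, and the Hall of Mental Cultivation, and learn about court life and culture in ancient China. The Forbidden City attracts millions of visitors every year and is an important window into Chinese history and culture.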

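The Terracotta Warriors

The Terracotta Warriors, located in Xi'an, Shaanxi Province, are a military relic of China's ancient Qin dynasty. Their discovery in 1974 stunned the world, and they have become an important symbol of ancient Chinese civilization. The pits have yielded thousands of pottery soldiers and horses, each with its own facial expression and posture, showcasing the military power and craftsmanship of the Qin dynasty. The Terracotta Warriors museum is an important place for visitors to learn about Qin history and culture and draws large numbers of domestic and international visitors every year.

West Lake

West Lake, in Hangzhou, Zhejiang Province, is one of China's most famous lakes and is hailed as "paradise on earth." It is renowned for its beautiful scenery and rich cultural heritage. The lake is dotted with famous sights such as Three Pools Mirroring the Moon, Leifeng Pagoda, and the Broken Bridge, and the Su Causeway and Bai Causeway along its shores are exceptionally beautiful. In spring, peach blossoms bloom everywhere. In summer, lotus flowers open and the air is refreshingly cool. In autumn, the hills turn red, and in winter the snowy scenery is as pretty as a painting, with a charm all its own.

Guilin's Landscapes

The landscapes of Guilin, in the Guangxi Zhuang Autonomous Region, are known for their magnificent mountains and rivers and distinctive karst terrain. The Li River is the signature of Guilin's scenery; visitors can take a bamboo raft or a cruise boat along the river to admire the strange peaks and rock formations on both banks and the clear water. Guilin's Elephant Trunk Hill, Seven Star Park, and Reed Flute Cave also attract many visitors. Guilin's landscapes are not only naturally beautiful but also rich in cultural meaning, making them an excellent place to experience China's tradition of landscape culture.

Jiuzhaigou

Jiuzhaigou, in the Ngawa Tibetan and Qiang Autonomous Prefecture of Sichuan Province, is one of China's most famous natural scenic areas. It is known for its colorful lakes, waterfalls, snow-capped mountains, and forests, and is hailed as a "fairy-tale world." Sights within the area, such as Pearl Shoal, Five Flower Lake, and Long Lake, are breathtaking and draw countless visitors. Each season in Jiuzhaigou has its own character: peach blossoms and azaleas in spring, cool greenery in summer, red and golden leaves in autumn, and snowscapes in winter, all of which make visitors reluctant to leave.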

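The Potala Palace

The Potala Palace, in Lhasa, the capital of the Tibet Autonomous Region, is Tibet's most famous landmark. This magnificent palace was first built in the 7th century AD and is a symbol of Tibetan Buddhism. Built against a mountainside, it has 13 stories and stands 117 meters tall, and it embodies Tibetan history and culture. Visitors can tour its halls and learn about the history and culture of Tibetan Buddhism. The Potala Palace is not only a center of religious faith but also an important window for visitors into Tibetan culture.

Changbai Mountain

Changbai Mountain, in southeastern Jilin Province, is a famous mountain range in Northeast China. It is known for its magnificent natural scenery and rich biodiversity. Its Heaven Lake (Tianchi) is one of the deepest alpine lakes in the world, with a surface elevation of 2,189 meters and clear, blue water. The mountain's waterfalls, hot springs, and primeval forests also attract many visitors. Changbai Mountain is also a well-known ski destination in China, and its winter skiing attracts many enthusiasts.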

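Zhangjiajie

Zhangjiajie, in northwestern Hunan Province, is a famous natural scenic area in China. It is known for its distinctive sandstone pillar landforms and magnificent scenery. Sights within the area, such as Yuanjiajie, Tianzi Mountain, and Golden Whip Stream, are breathtaking and draw countless visitors. Zhangjiajie is also a famous filming location; films such as Avatar drew on its scenery, which greatly raised its profile. Besides its natural beauty, Zhangjiajie has rich folk culture and historical sites, making it an excellent place to experience the natural scenery and culture of southern China.

The Summer Palace

The Summer Palace, in the western suburbs of Beijing, is one of China's ancient imperial gardens. Built during the reign of the Qianlong Emperor of the Qing dynasty, it covers 290 hectares and is a masterpiece of Chinese garden art. Its famous sights include Longevity Hill, Kunming Lake, the Long Corridor, and the Tower of Buddhist Incense, with exquisite buildings and pleasant scenery. The Summer Palace not only showcases Chinese garden art but also carries rich historical and cultural meaning, making it an excellent place for visitors to learn about imperial life and culture in ancient China.

The Xi'an City Wall

The Xi'an City Wall, in Xi'an, Shaanxi Province, is one of the most complete surviving ancient city walls in China. Built in the Ming dynasty, it is about 14 km long, about 12 m high, and about 18 m thick. Visitors can climb onto the wall and look out over downtown Xi'an, feeling the grandeur of an ancient capital. There are also many historical and cultural sites in the area, such as the Terracotta Warriors and the Giant Wild Goose Pagoda, making it an important place for visitors to learn about ancient Chinese history and culture.

These major attractions not only showcase China's natural scenery, history, and culture but also offer visitors a rich variety of travel experiences. Whether you are a history enthusiast, a nature explorer, or a culture seeker, China has a destination for you.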

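Chinese Food Culture

Chinese food culture is broad and profound, with a dazzling array of regional specialties, each carrying its own local flavor and cultural meaning. From the noodles of the north to the rice of the south, and from the seafood of the east to the beef and mutton of the west, the diversity and richness of Chinese cuisine are astonishing.

Northern Cuisine

The north relies mainly on flour-based foods, especially wheat noodles, such as Beijing zhajiangmian (noodles with fried bean sauce), Shanxi daoxiaomian (knife-cut noodles), and Shaanxi youpomian (hot-oil-splashed noodles). Beijing zhajiangmian is famous for its distinctive sauce and the texture of its noodles and is a classic Beijing dish. Shanxi knife-cut noodles are loved for their unique preparation and thick noodles. Shaanxi youpomian is made by pouring hot oil over the noodles, giving it a rich aroma and delicious flavor.

The north also has many famous meat dishes, such as Inner Mongolian whole roast lamb, Xinjiang lamb skewers, and Lanzhou beef noodles. Inner Mongolian whole roast lamb is known for its tender meat and rich aroma and is the highest honor the Mongolian people extend to distinguished guests. Xinjiang lamb skewers are a common street snack, fragrant and flavorful, and are a signature of Xinjiang cuisine. Lanzhou beef noodles are famous for their distinctive broth and noodle texture and are an everyday food for the people of Lanzhou.

Southern Cuisine

The south relies mainly on rice and seafood, with examples such as Cantonese morning tea, Sichuan hot pot, and Jiangsu's Subang cuisine. Guangdong's morning tea (yum cha) culture has a long history, centered on dim sum and tea, such as har gow (shrimp dumplings), siu mai, and egg tarts, and it is an indispensable part of the morning for the Cantonese. Sichuan hot pot is famous for being numbing, spicy, fresh, and fragrant and is a favorite of the Sichuanese; with both spicy and clear-broth versions, it can suit different tastes.

Jiangsu's Subang cuisine is known for its light, fresh flavors and its emphasis on the natural taste of ingredients, with dishes such as Yangzhou fried rice and squirrel-shaped mandarin fish. Subang cuisine emphasizes harmony of color, aroma, taste, and form and is representative of the cuisine of the Jiangnan region. Southern seafood is also highly distinctive, with dishes such as Guangdong's steamed seafood, Fujian's Buddha Jumps Over the Wall, and Zhejiang's West Lake fish in vinegar sauce. These seafood dishes attract countless visitors with their fresh taste and distinctive cooking methods.

Western Cuisine

The west relies mainly on beef, mutton, and flour-based foods, such as Xinjiang polo (hand-pilaf), Qinghai yangrou paomo (mutton soup with crumbled flatbread), and Gansu lamian (hand-pulled noodles). Xinjiang polo is known for its distinctive cooking method and rich ingredients and is a traditional favorite in Xinjiang. Qinghai yangrou paomo consists mainly of mutton soup and bread, with a fresh, savory taste, and is a Qinghai specialty. Gansu hand-pulled noodles are famous for their unique preparation and texture and are an everyday food in Gansu.

Local Snacks

China's snack culture is rich and varied, and each region's snacks have their own flavors and preparation methods, such as Beijing douzhi (fermented mung bean milk), Tianjin Goubuli baozi (steamed buns), Shanghai shengjianbao (pan-fried buns), Xi'an roujiamo (meat in flatbread), and Chengdu mala chuan (spicy skewers). These snacks not only taste distinctive but also carry rich cultural meaning, making them an important way for visitors to experience local life.

In short, Chinese food culture is rich and varied, and each region's specialties have their own character. Whether staple foods or snacks, they can satisfy every visitor's palate. Tasting Chinese food is one of the most important ways to understand Chinese culture and way of life.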

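Traditional Culture and Festivals in China

China is a country with a long history and a rich culture, and its traditional culture and festivals are rich and varied, reflecting the wisdom and emotions of the Chinese nation. Below are some of the most important traditional festivals:

Spring Festival

The Spring Festival, also called the Lunar New Year, is China's most important traditional festival, celebrated on the first day of the first month of the lunar calendar. It marks the beginning of a new year, and people celebrate by pasting up spring couplets, setting off firecrackers, paying New Year visits, and eating a reunion dinner. During the festival, every household cleans house and puts up red couplets and paper window decorations to create a festive atmosphere. On New Year's Eve, families gather for the New Year's Eve dinner, which symbolizes reunion and happiness. The festival also features dragon and lion dances and temple fairs, which draw many visitors who want to experience Chinese New Year culture.

Qingming Festival

The Qingming Festival, usually on April 4 or 5 each year, is one of China's traditional festivals and the day for honoring ancestors and sweeping tombs. During the festival, people visit cemeteries to pay respects to their ancestors, sweeping the graves and offering flowers to express respect and remembrance. Qingming is also a time for spring outings and flower viewing, when people walk in the countryside and enjoy the spring scenery. It is not only an important traditional festival but also a good opportunity to get close to nature and relax.

Dragon Boat Festival

The Dragon Boat Festival, celebrated on the fifth day of the fifth lunar month, is one of China's traditional festivals. Its customs include eating zongzi and racing dragon boats. Zongzi, the festival's traditional food, are made of glutinous rice with various fillings and come in many shapes and flavors. Dragon boat racing is the festival's most representative activity; people organize races to commemorate the ancient patriotic poet Qu Yuan. The Dragon Boat Festival is not only a celebration but also an important occasion for passing on and promoting traditional Chinese culture.

Mid-Autumn Festival

The Mid-Autumn Festival, celebrated on the fifteenth day of the eighth lunar month, is China's traditional festival of reunion. Its main activities are admiring the moon and eating mooncakes. Mooncakes, the festival's traditional food, are round, symbolizing reunion and completeness. People admire the moon together in the evening and share mooncakes, expressing their longing for and good wishes to family and friends. The festival is not only a day for family reunion but also an important occasion for expressing feelings and remembrance.

Double Ninth Festival

The Double Ninth Festival, celebrated on the ninth day of the ninth lunar month, is one of China's traditional festivals. Its customs include climbing to high places, admiring chrysanthemums, and eating Double Ninth cakes. Climbing is the festival's main activity; people go to high places to gaze into the distance and pray for health and longevity. Admiring chrysanthemums is another important activity, in which people enjoy the many beautiful varieties and feel the arrival of autumn. The festival is not only a celebration but also an important time to pray for health and longevity.

Other Traditional Festivals

Besides the important festivals above, China has many other traditional festivals, such as the Lantern Festival, the Qingming Festival, the Dragon Boat Festival, and the Mid-Autumn Festival. Each has its own way of celebrating and its own traditional customs, reflecting the wisdom and emotions of the Chinese nation. For example, the Lantern Festival features lantern viewing and lantern riddles; Qingming features spring outings and tomb sweeping; the Dragon Boat Festival features dragon boat races and zongzi; and the Mid-Autumn Festival features moon viewing and mooncakes.

In short, China's traditional culture and festivals are rich and varied, and every festival carries deep cultural meaning and national sentiment. By taking part in these festivities, visitors can not only feel the charm of Chinese culture but also better understand and experience the traditional Chinese way of life.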

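Shopping in China

China is a shopping paradise: whether they want traditional handicrafts or modern shopping malls, visitors can find what they are looking for. Here are some famous shopping destinations and specialty goods for reference:

Beijing

As China's capital, Beijing has abundant shopping options. Wangfujing Street is one of Beijing's most famous shopping streets, with a variety of domestic and international brands and traditional handicraft shops. Beijing also has many specialty markets, such as the Panjiayuan Antique Market, where you can find all kinds of antiques, artwork, and handicrafts.

Specialty goods: Peking duck, silk, cloisonné, porcelain, antiques.

Shanghai

Shanghai is China's economic center, with many high-end shopping malls and fashion brand stores. Nanjing Road Pedestrian Street is one of Shanghai's busiest commercial streets and brings together a wide range of domestic and international brands. Shanghai also has the Yuyuan Bazaar, where you can find many traditional handicrafts and specialty goods.

Specialty goods: silk, Shanghai qipao (cheongsam), jewelry, watches, tea.

Chengdu

Chengdu is a city full of everyday life, with many specialty markets and shopping malls. Kuanzhai Alley (the Wide and Narrow Alleys) is a famous Chengdu attraction and also a good place to shop, with many traditional handicrafts and local snacks. Chengdu also has modern shopping areas such as Chunxi Road and Taikoo Li.

Specialty goods: hot pot base, mala chuan (spicy skewers), Shu brocade, Shu embroidery, tea.

Xi'an

As an ancient capital, Xi'an has a rich historical and cultural heritage and many traditional handicrafts. The Muslim Quarter (Huimin Street) is a famous Xi'an attraction and also a good place to shop, offering all kinds of traditional snacks and handicrafts, such as paper-cuts, pottery, and shadow puppets.

Specialty goods: Terracotta Warrior replicas, paper-cuts, pottery, shadow puppets, yangrou paomo seasoning.

Guilin

Guilin is famous for its beautiful scenery and rich folk culture. In the city center, visitors can find many specialty stores and handicraft shops, for example on Dongxi Lane and Zhengyang Pedestrian Street. Here you can buy all kinds of Guilin specialties, such as Guilin rice noodles, Li River stone paintings, and bamboo weaving.

Specialty goods: Guilin rice noodles, Li River stone paintings, bamboo weaving, tea.

Hong Kong

Hong Kong is an international shopping paradise with brands and specialty goods from around the world. Shopping areas such as Causeway Bay and Tsim Sha Tsui are great places to shop. Hong Kong also has many famous shopping streets, such as those in SoHo and Central.

Specialty goods: jewelry, watches, cosmetics, silk, electronics.

Macau

Macau is known for its distinctive Portuguese culture and rich cuisine. In its shopping areas, such as the Grand Lisboa shopping center and the shopping center at The Venetian, visitors can find a variety of international brands and specialty goods.

Specialty goods: jewelry, watches, cosmetics, Portuguese-style pastries, wine.

In short, shopping in China is rich and varied; whether traditional handicrafts or modern malls, there is something for every visitor. Visitors can choose shopping destinations and specialty goods that suit their interests and needs.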

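Accommodation in China

China offers a wide range of accommodation, from luxury hotels to budget hotels and from traditional inns to youth hostels. Whatever your budget and preferences, you can find a suitable place to stay. Below are some common types of accommodation and recommendations:

Luxury Hotels

Luxury hotels are usually located in city centers or near tourist attractions and offer high-quality service and facilities. They typically have spacious rooms, luxurious bathrooms, fine dining, and fitness centers. Examples include the China World Hotel in Beijing, the Pudong Shangri-La in Shanghai, and the Hilton Chengdu. Luxury hotels suit visitors who value comfort and quality, especially business travelers and high-end tourists.

Budget Hotels

Budget hotels are moderately priced and offer basic facilities such as clean, tidy rooms, hot water, and air conditioning. They are usually located in busy urban areas or near tourist attractions, making it easy for visitors to get around. Examples include Home Inn in Beijing, 7 Days Inn in Shanghai, and Hanting Hotel in Chengdu. Budget hotels suit visitors on a limited budget, especially backpackers and short-term travelers.

Traditional Inns

Traditional inns (kezhan) are a distinctly Chinese form of accommodation, usually found in ancient towns and old cities, with a rustic architectural style and simple, elegant interiors. Besides lodging, they often provide local cuisine and travel advice. Examples include the Sifang Street Inn in Lijiang, the Pingjiang Inn in Suzhou, and the Huimin Street Inn in Xi'an. Traditional inns suit visitors who like to experience local culture and atmosphere.

Youth Hostels

Youth hostels are an economical choice, usually offering dormitories and private rooms, and suit young people and backpackers. Facilities are simple, but hostels usually have shared kitchens, lounges, and activity rooms that make it easy for travelers to meet and interact. Examples include youth hostels in Beijing and Shanghai and the Gulangyu Youth Hostel in Xiamen. Youth hostels suit young people who like making new friends and experiencing local culture.

Recommended Accommodation

  1. Beijing: China World Hotel, Home Inn, a traditional inn in Nanluoguxiang.
  2. Shanghai: Pudong Shangri-La, 7 Days Inn, a traditional inn near Yuyuan Garden.
  3. Chengdu: Hilton Chengdu, Hanting Hotel, a traditional inn in Kuanzhai Alley.
  4. Lijiang: Sifang Street Inn, Lijiang Old Town Youth Hostel.
  5. Suzhou: Pingjiang Inn, Suzhou Garden Hotel.
  6. Xiamen: Gulangyu Youth Hostel, Xiamen International Hotel.

In short, China's accommodation options are rich and varied; whatever your budget and preferences, you can find a suitable place to stay. Choosing the right accommodation can make your trip more comfortable and enjoyable.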

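Getting Around in China

China's transportation network is well developed and offers visitors convenient ways to travel. Domestic flights, trains, and long-distance buses can meet the needs of different travelers. Below are the main modes of transportation and their characteristics:

Domestic Flights

Domestic flights are one of the main modes of long-distance travel in China. China has many large airports, such as Beijing Capital International Airport, Shanghai Pudong International Airport, and Guangzhou Baiyun International Airport, connecting major cities at home and abroad. Domestic flights are usually fast and convenient and suit long-distance travel. Visitors can book tickets through airline websites or online travel agencies, and booking early usually gets better prices. Note that domestic flights require check-in and baggage drop in advance, so visitors are advised to arrive at the airport early to avoid delays.

Trains

China's railway network is highly developed and covers most of the country. Its high-speed rail (gaotie) system is particularly advanced, fast, and comfortable, and is an important option for long-distance travel. High-speed lines connect major cities such as Beijing, Shanghai, Guangzhou, Chengdu, and Chongqing, with journeys usually taking a few hours at most. Conventional train lines also reach remote areas and suit travelers on a limited budget. Train tickets can be bought through the official 12306 website or at station ticket windows. Note that tickets can be scarce during holidays and the peak tourist season, so visitors are advised to book early.

Long-Distance Buses

Long-distance buses are the main way to reach small and medium-sized cities and remote areas. Long-distance bus stations are usually located in city centers or suburbs and offer routes to nearby cities and scenic areas. Fares are relatively low, making buses suitable for budget travelers. Tickets can be booked at station ticket windows or through online travel agencies. Note that long-distance buses take longer and are less comfortable, so they suit short trips or travelers on a limited budget.

Getting Around Cities

Within cities, visitors can take the subway, buses, taxis, and shared bikes. The subway is the main mode of urban transport, covering most city-center areas, with fast service and reasonable fares. Buses suit short trips and are cheap but slower. Taxis and ride-hailing services (such as DiDi) offer convenient options for travelers in a hurry. Shared bikes are handy for short rides around town and suit visitors who enjoy being outdoors.

In short, China's well-developed transportation network gives visitors many ways to get around. Whichever mode you choose, you can reach your destination conveniently and quickly. Planning your itinerary and booking tickets in advance will make your trip smoother and more enjoyable.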

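Travel Safety and Precautions in China

While enjoying China's rich and varied travel experiences, visitors should pay attention to some safety issues and precautions to ensure a smooth trip. Below are some common safety and health concerns and how to deal with them:

Safety

  1. Personal belongings: In public places such as train stations, airports, and tourist sites, visitors should keep an eye on their belongings at all times. Keep valuables on your person and never leave them unattended. Visitors should also avoid going out alone at night, especially in remote or sparsely populated places.

  2. Traffic safety: When using public transportation, visitors should follow traffic rules, for example by using crosswalks when crossing the street and making sure taxi drivers use the meter. Visitors on self-drive trips should familiarize themselves with local traffic rules and obey traffic signals to drive safely.

  3. Natural disasters: Some parts of China may experience natural disasters such as earthquakes, floods, and typhoons. Before traveling, visitors should check weather and disaster warnings for their destination and prepare accordingly. If a disaster occurs, follow the instructions of local authorities and tour guides and avoid dangerous areas.

  4. Public security: Public security in China is good overall, but visitors should still take care of their personal safety in some areas. Avoid going out alone at night, especially in sparsely populated places. In an emergency, call the police promptly and seek help.

Health

  1. Food hygiene: China's food culture is rich and varied, but visitors should pay special attention to food hygiene when sampling local cuisine. Avoid raw or cold food, street food, and unfamiliar food to prevent food poisoning. Visitors should also carry some common medicines, such as cold medicine, anti-inflammatories, and adhesive bandages.

  2. Acclimatization: China is vast, and water quality and climate vary from place to place, so visitors may have trouble adjusting during their trip. Visitors are advised to bring common medicines such as antidiarrheals and anti-inflammatories. They should also drink plenty of water to stay hydrated.

  3. Sun and insect protection: During outdoor activities, visitors should protect themselves from the sun by applying sunscreen and wearing hats and sunglasses. They should also avoid insect bites by carrying repellents and related products, such as mosquito coils and insect-repellent spray.

  4. Medical care: Before traveling, visitors should find out about medical resources at their destination, such as hospitals and pharmacies. If they feel unwell during the trip, they should seek medical attention promptly. Visitors should also buy travel insurance to cover possible medical expenses and emergencies.

In short, understanding and paying attention to safety and health issues is an important guarantee of a smooth trip. Visitors should prepare thoroughly before traveling and respect local regulations and customs to keep their trip safe and enjoyable.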

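The Best Seasons to Travel in China

China's climate is diverse, and each region and season has its own character, so choosing the right season is crucial to a good travel experience. Below are the best travel seasons for each part of China and the reasons for recommending them:

Northern China

Best seasons: spring (March to May) and autumn (September to November)

Why go then

  • Spring: Temperatures in the north are mild in spring, and everything comes back to life in lush green. Cherry blossoms bloom in Beijing, Tianjin, and elsewhere, and the Harbin Ice and Snow Festival is also a good spring option, with beautiful ice sculptures and snowscapes.
  • Autumn: The north has pleasant, crisp, clear weather in autumn and beautiful scenery. The red-leaf season in Beijing, Xi'an, and elsewhere is spectacular and a top choice for photographers. Autumn is also harvest season, when visitors can taste all kinds of fresh fruit and food.

Central China

Best seasons: spring (March to May) and autumn (September to November)

Why go then

  • Spring: Central China is mild and pleasant in spring, the best time to visit cities such as Wuhan and Changsha. Wuhan's cherry blossoms are in bloom, and Changsha's Yuelu Mountain is a good place for a spring outing.
  • Autumn: Central China is cool, crisp, and clear in autumn, the ideal season to visit cities such as Nanjing and Hangzhou. Nanjing is filled with the fragrance of osmanthus, and Hangzhou's West Lake is simply enchanting.

Southern China

Best seasons: spring (March to May) and autumn (September to November)

Why go then

  • Spring: The south is warm in spring, with everything coming back to life; it is the best time to visit cities such as Guangzhou and Shenzhen. Lychee trees blossom in Guangzhou, and Shenzhen's Overseas Chinese Town (OCT) is also a great place to visit in spring.
  • Autumn: The south has pleasant, crisp, clear weather in autumn, the ideal season to visit cities such as Xiamen and Fuzhou. Sights such as Gulangyu in Xiamen and the Three Lanes and Seven Alleys in Fuzhou are especially charming in autumn.

Western China

Best seasons: summer (June to August) and autumn (September to November)

Why go then

  • Summer: Western China is cool in summer, the best time to visit scenic areas such as the Qinghai-Tibet Plateau and Jiuzhaigou. The blue skies and white clouds of the plateau and the emerald lakes of Jiuzhaigou are unique summer sights.
  • Autumn: Western China is pleasant, crisp, and clear in autumn, the ideal season to visit Xinjiang, Gansu, and elsewhere. Xinjiang is fragrant with melons and fruit, and Gansu's Danxia landforms are colorful and spectacular.

In short, choosing the right season lets you enjoy the best scenery and avoid peak-season crowds, making your trip more comfortable and enjoyable. Visitors can choose the season that suits their interests and needs for traveling around China.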

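Practical Travel Tips for China

To help your trip to China go smoothly and enjoyably, here are some practical tips for preparing and for handling the situations you may meet on the road:

Before You Go

  1. Visa: Apply for the appropriate visa in advance based on your nationality and purpose of travel. The Chinese visa process can be fairly complex, so research it early and prepare all required documents.

  2. Itinerary planning: Plan your route and itinerary in advance, including sights, accommodation, and transport, so you do not lose time to uncertainties during the trip.

  3. Flights and hotels: Booking flights and hotels early can get you better prices. In peak season especially, booking ahead helps keep your plans on track.

  4. Travel insurance: Buy travel insurance to cover emergencies such as accidental injury, illness, and lost luggage.

  5. Local culture and customs: Learn about your destination's culture and customs before you go, and respect local practices to avoid unnecessary conflicts.

During Your Trip

  1. Important documents: Keep your passport, visa, ID card, and tickets with you and make sure they are secure.

  2. Personal items: Bring essentials such as your wallet, phone, camera, charger, and plug adapter. Keep valuables safe to avoid loss or theft.

  3. Traffic rules: Learn the local traffic rules; if you are driving, obey traffic signals and regulations to stay safe.

  4. Food hygiene: When sampling local cuisine, pay attention to hygiene and avoid raw or cold food and unfamiliar dishes to prevent food poisoning.

  5. Sun and insect protection: During outdoor activities, protect yourself against sunburn and insect bites by carrying sunscreen and insect repellent.

Handling Emergencies

  1. Emergencies: If something goes wrong, such as losing your documents, falling ill, or getting injured, call the police promptly and contact your travel agency or local friends for help.

  2. Medical care: Find out about the medical facilities at your destination, such as hospitals and pharmacies. Before the trip, pack common medicines such as cold remedies, anti-inflammatories, and adhesive bandages.

  3. Staying connected: Keep your phone and internet access working by buying a local SIM card or using international roaming, so you can reach family and friends in an emergency.

In short, with thorough preparation and attention to detail, you can make your trip to China smooth and enjoyable. These practical tips will help you handle whatever comes up and enjoy your journey.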

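China Travel: Summary and Outlook

China, an ancient and mysterious land in the East, attracts countless visitors with its rich history, diverse natural scenery, and unique culture. From the majestic Great Wall to the magnificent Forbidden City, and from Jiuzhaigou's clear waters and blue skies to the tranquil, timeless Old Town of Lijiang, every site leaves visitors reluctant to leave. China's cuisine is mouthwatering: whether it is the wheat-based dishes of the north, the rice of the south, or the beef and lamb of the west, each carries a deep cultural heritage. China's traditional culture and festivals are equally rich and colorful, from family reunions at the Spring Festival to dragon boat races at the Dragon Boat Festival, and every festival conveys a strong sense of shared national feeling.

The appeal of travel in China lies not only in its rich natural and cultural landscapes but also in its warm and hospitable people. Whether you are a history buff, a nature explorer, or a culture seeker, you can find your own travel paradise in China. Bustling cities, quiet ancient towns, or magnificent natural scenery: China can meet every kind of travel need.

Looking ahead, China will continue to invest in tourism infrastructure and improve service quality, offering travelers worldwide a more convenient and comfortable experience. As China's tourism industry grows, more international visitors will come to explore this ancient and mysterious land in the East.

We warmly invite you to visit China and experience the unique charm of this land. Historical sites, natural scenery, and food culture: China will give you an unforgettable travel experience. We look forward to welcoming you as you explore the beauty and wonder of China!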

📝 Citation

If you find our work useful, please kindly cite:

@article{bai2024longwriter,
  title={LongWriter: Unleashing 10,000+ Word Generation from Long Context LLMs}, 
  author={Yushi Bai and Jiajie Zhang and Xin Lv and Linzhi Zheng and Siqi Zhu and Lei Hou and Yuxiao Dong and Jie Tang and Juanzi Li},
  journal={arXiv preprint arXiv:2408.07055},
  year={2024}
}

THUDM/LongWriter/blob/main/agentwrite/plan.py:

```py
import requests
import time, os, json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
import random
import codecs
import argparse
from copy import deepcopy
from tqdm import tqdm
import traceback
import re
import torch.distributed as dist
import torch.multiprocessing as mp

GPT4_API_KEY = ''
GPT_MODEL = 'gpt-4o-2024-05-13'
def get_response_gpt4(prompt, max_new_tokens=1024, temperature=1.0, stop=None):
    tries = 0
    while tries < 10:
        tries += 1
        try:
            headers = {
                'Authorization': "Bearer {}".format(GPT4_API_KEY),
            }
            messages = [
                {'role': 'user', 'content': prompt},
            ]
            resp = requests.post("https://api.openai.com/v1/chat/completions", json = {
                "model": GPT_MODEL,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_new_tokens,
                "stop": stop,
            }, headers=headers, timeout=600)
            if resp.status_code != 200:
                raise Exception(resp.text)
            resp = resp.json()
            break
        except KeyboardInterrupt as e:
            raise e
        except Exception as e:
            if "maximum context length" in str(e):
                raise e
            elif "triggering" in str(e):
                return 'Trigger OpenAI\'s content management policy'
            print("Error Occurs: \"%s\"        Retry ..."%(str(e)))
    else:
        print("Max tries. Failed.")
        return "Max tries. Failed."
    try:
        return resp["choices"][0]["message"]["content"]
    except: 
        return ''

def get_pred(rank, world_size, data, max_new_tokens, fout, template):
    for item in tqdm(data):
        prompt = item['prompt']
        prompt = template.replace('$INST$', prompt)
        try:
            response = get_response_gpt4(prompt, max_new_tokens)
            item["plan"] = response
            fout.write(json.dumps(item, ensure_ascii=False)+'\n')
            fout.flush()
        except Exception as e:
            print(e)

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)

if __name__ == '__main__':
    # input format: {"prompt": "xxx", ...}
    # output format: {"prompt": "xxx", "plan": "xxx", ...}
    in_file = 'instructions.jsonl'
    out_file = 'plan.jsonl'
    seed_everything(42)
    max_new_tokens = 4096
    world_size = 8
    has_data = {}
    if os.path.exists(out_file):
        with open(out_file, encoding='utf-8') as f:
            has_data = {json.loads(line)["prompt"]: 0 for line in f}
    fout = open(out_file, 'a', encoding='utf-8')
    data = []
    with open(in_file, encoding='utf-8') as f:
        for line in f:
            item = json.loads(line)
            if item["prompt"] not in has_data:
                data.append(item)
    template = open('prompts/plan.txt', encoding='utf-8').read()

    data_subsets = [data[i::world_size] for i in range(world_size)]
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=get_pred, args=(rank, world_size, data_subsets[rank], max_new_tokens, fout, template))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```

THUDM/LongWriter/blob/main/agentwrite/prompts/plan.txt:

I need you to help me break down the following long-form writing instruction into multiple subtasks. Each subtask will guide the writing of one paragraph in the essay, and should include the main points and word count requirements for that paragraph.

The writing instruction is as follows:

$INST$

Please break it down in the following format, with each subtask taking up one line:

Paragraph 1 - Main Point: [Describe the main point of the paragraph, in detail] - Word Count: [Word count requirement, e.g., 400 words]

Paragraph 2 - Main Point: [Describe the main point of the paragraph, in detail] - Word Count: [word count requirement, e.g. 1000 words].

...

Make sure that each subtask is clear and specific, and that all subtasks cover the entire content of the writing instruction. Do not split the subtasks too finely; each subtask's paragraph should be no less than 200 words and no more than 1000 words. Do not output any other content.
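Because the plan format requested above is line-oriented, downstream code can recover each subtask with a regular expression. The sketch below uses a hypothetical helper: `parse_plan` and `PlanStep` are not part of the repository, and `write.py` simply splits the plan on newlines and passes each raw line through as `$STEP$`. It assumes the model follows the `Paragraph N - Main Point: ... - Word Count: ...` template exactly.

```python
import re
from dataclasses import dataclass
from typing import List

# Hypothetical helper (not in the repository): matches one line of the plan.txt format.
PLAN_LINE = re.compile(
    r"^Paragraph\s+(\d+)\s*-\s*Main Point:\s*(.*?)\s*-\s*Word Count:\s*(\d+)",
    re.IGNORECASE,
)


@dataclass
class PlanStep:
    index: int
    main_point: str
    word_count: int


def parse_plan(plan: str) -> List[PlanStep]:
    """Parse a plan into steps, skipping blank lines and lines that do not match."""
    steps: List[PlanStep] = []
    for line in plan.strip().splitlines():
        match = PLAN_LINE.match(line.strip())
        if match:
            steps.append(PlanStep(int(match.group(1)), match.group(2), int(match.group(3))))
    return steps


plan = """Paragraph 1 - Main Point: Introduce the topic - Word Count: 400 words

Paragraph 2 - Main Point: Discuss transport options - Word Count: 1000 words"""
steps = parse_plan(plan)
for step in steps:
    print(step)
print(sum(s.word_count for s in steps))  # total target length: 1400
```

Summing the per-paragraph targets gives a quick sanity check that a plan actually adds up to the requested length. Note that `write.py` separately skips any plan with more than 50 lines.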

THUDM/LongWriter/blob/main/agentwrite/prompts/write.txt:

You are an excellent writing assistant. I will give you an original writing instruction and my planned writing steps. I will also provide you with the text I have already written. Please help me continue writing the next paragraph based on the writing instruction, writing steps, and the already written text.

Writing instruction:

$INST$

Writing steps:

$PLAN$

Already written text:

$TEXT$

Please integrate the original writing instruction, writing steps, and the already written text, and now continue writing $STEP$. If needed, you can add a small subtitle at the beginning. Remember to only output the paragraph you write, without repeating the already written text.
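The `$INST$`, `$PLAN$`, `$TEXT$`, and `$STEP$` placeholders are filled by plain string replacement, once per plan step. Each new paragraph is appended to the running text before the next call. The minimal sketch below mirrors that loop from `write.py` with the API call abstracted away. `fake_llm` is a hypothetical stand-in for `get_response_gpt4`, and `TEMPLATE` is a shortened, assumed version of `write.txt`.

```python
from typing import Callable, List

# Shortened stand-in for prompts/write.txt; only the placeholders matter here.
TEMPLATE = (
    "Writing instruction:\n$INST$\n\nWriting steps:\n$PLAN$\n\n"
    "Already written text:\n$TEXT$\n\nNow continue writing $STEP$."
)


def agent_write(inst: str, plan: str, llm: Callable[[str], str]) -> List[str]:
    """Write one paragraph per plan line, feeding back everything written so far."""
    plan = plan.strip().replace("\n\n", "\n")
    text = ""
    responses: List[str] = []
    for step in plan.split("\n"):
        prompt = (TEMPLATE.replace("$INST$", inst)
                  .replace("$PLAN$", plan)
                  .replace("$TEXT$", text.strip())
                  .replace("$STEP$", step.strip()))
        response = llm(prompt)
        if response == "":  # write.py treats an empty reply as a failure and stops
            break
        responses.append(response)
        text += response + "\n\n"
    return responses


def fake_llm(prompt: str) -> str:
    # Hypothetical model: echoes back the step it was asked to write.
    step = prompt.rsplit("continue writing ", 1)[1].rstrip(".")
    return f"<text for {step}>"


paragraphs = agent_write("Write a China travel guide.",
                         "Paragraph 1 - Intro\n\nParagraph 2 - Tips", fake_llm)
print(paragraphs)
```

Because the whole plan and all previous paragraphs are resent on every call, prompt length grows with each step. This growth is why the approach relies on a long-context model for the writing calls.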

THUDM/LongWriter/blob/main/agentwrite/write.py:

import requests
import time, os, json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
import random
import codecs
import argparse
from copy import deepcopy
from tqdm import tqdm
import traceback
import re
import torch.distributed as dist
import torch.multiprocessing as mp

GPT4_API_KEY = ''
GPT_MODEL = 'gpt-4o-2024-05-13'
def get_response_gpt4(prompt, max_new_tokens=1024, temperature=1.0, stop=None):
    tries = 0
    while tries < 10:
        tries += 1
        try:
            headers = {
                'Authorization': "Bearer {}".format(GPT4_API_KEY),
            }
            messages = [
                {'role': 'user', 'content': prompt},
            ]
            resp = requests.post("https://api.openai.com/v1/chat/completions", json = {
                "model": GPT_MODEL,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_new_tokens,
                "stop": stop,
            }, headers=headers, timeout=600)
            if resp.status_code != 200:
                raise Exception(resp.text)
            resp = resp.json()
            break
        except KeyboardInterrupt as e:
            raise e
        except Exception as e:
            if "maximum context length" in str(e):
                raise e
            elif "triggering" in str(e):
                return 'Trigger OpenAI\'s content management policy'
            print("Error Occurs: \"%s\"        Retry ..."%(str(e)))
    else:
        print("Max tries. Failed.")
        return "Max tries. Failed."
    try:
        return resp["choices"][0]["message"]["content"]
    except: 
        return ''

def get_pred(rank, world_size, data, max_new_tokens, fout, template, cache_fout, cache_dict):
    for item in tqdm(data):
        try:
            inst = item['prompt']
            plan = item['plan'].strip().replace('\n\n', '\n')
            steps = plan.split('\n')
            text = ""
            responses = []
            if len(steps) > 50:
                print(plan)
                continue
            for step in steps:
                if inst in cache_dict and step in cache_dict[inst]:
                    response = cache_dict[inst][step]
                    responses.append(response)
                    text += response + '\n\n'
                    continue
                prompt = template.replace('$INST$', inst).replace('$PLAN$', plan.strip()).replace('$TEXT$', text.strip()).replace('$STEP$', step.strip())
                response = get_response_gpt4(prompt, max_new_tokens)
                if response == '':
                    break
                # save to cache
                cache_fout.write(json.dumps({"prompt": inst, "step": step, "response": response}, ensure_ascii=False)+'\n')
                cache_fout.flush()
                responses.append(response)
                text += response + '\n\n'
            if response == '':
                continue
            item["write"] = responses
            fout.write(json.dumps(item, ensure_ascii=False)+'\n')
            fout.flush()
        except Exception as e:
            print(e)

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)

if __name__ == '__main__':
    # input format: {"prompt": "xxx", "plan": "xxx", ...}
    # output format: {"prompt": "xxx", "plan": "xxx", "write": [...], ...}
    in_file = 'plan.jsonl'
    out_file = 'write.jsonl'
    cache_file = 'write_cache.jsonl'
    seed_everything(42)
    max_new_tokens = 4096
    world_size = 8
    has_data = {}
    if os.path.exists(out_file):
        with open(out_file, encoding='utf-8') as f:
            has_data = {json.loads(line)["prompt"]: 0 for line in f}
    cache_dict = {}
    if os.path.exists(cache_file):
        with open(cache_file, encoding='utf-8') as f:
            for line in f:
                item = json.loads(line)
                if item["prompt"] not in cache_dict:
                    cache_dict[item["prompt"]] = {}
                cache_dict[item["prompt"]][item["step"]] = item["response"]
    fout = open(out_file, 'a', encoding='utf-8')
    cache_fout = open(cache_file, 'a', encoding='utf-8')
    data = []
    with open(in_file, encoding='utf-8') as f:
        for line in f:
            item = json.loads(line)
            if item["prompt"] not in has_data:
                data.append(item)
    template = open('prompts/write.txt', encoding='utf-8').read()

    data_subsets = [data[i::world_size] for i in range(world_size)]
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=get_pred, args=(rank, world_size, data_subsets[rank], max_new_tokens, fout, template, cache_fout, cache_dict))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
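`write.py` makes the pipeline restartable in two ways. First, prompts already present in `write.jsonl` are skipped. Second, every finished paragraph is logged to `write_cache.jsonl` keyed by `(prompt, step)`, so a rerun only calls the API for missing steps. The work is then split across `world_size` processes with strided slicing. The sketch below isolates those two pieces, using in-memory JSONL lines instead of files. `load_cache` and `shard` are hypothetical names for logic that is inline in the script.

```python
import json
from typing import Dict, Iterable, List


def load_cache(lines: Iterable[str]) -> Dict[str, Dict[str, str]]:
    """Build {prompt: {step: response}} from write_cache.jsonl-style lines."""
    cache: Dict[str, Dict[str, str]] = {}
    for line in lines:
        item = json.loads(line)
        cache.setdefault(item["prompt"], {})[item["step"]] = item["response"]
    return cache


def shard(data: List[dict], world_size: int) -> List[List[dict]]:
    """Strided split, as in data[i::world_size]: worker i gets items i, i + ws, ..."""
    return [data[i::world_size] for i in range(world_size)]


cache_lines = [
    json.dumps({"prompt": "essay A", "step": "Paragraph 1", "response": "Intro..."}),
    json.dumps({"prompt": "essay A", "step": "Paragraph 2", "response": "Body..."}),
]
cache = load_cache(cache_lines)
print(cache["essay A"]["Paragraph 2"])  # Body...

subsets = shard([{"prompt": f"p{i}"} for i in range(5)], world_size=2)
print([[d["prompt"] for d in s] for s in subsets])  # [['p0', 'p2', 'p4'], ['p1', 'p3']]
```

Because the cache key is the raw step line, editing a plan only invalidates the steps whose text changed. Strided slicing keeps worker loads within one item of each other.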

THUDM/LongWriter/blob/main/evaluation/eval_length.py:

import json
import matplotlib.pyplot as plt
import numpy as np

def score(x, y):
    if y > x:
        return 100 * max(0, 1. - (y / x - 1) / 3)
    else:
        return 100 * max(0, 1. - (x / y - 1) / 2)

model = "LongWriter-glm4-9b"
prediction = [json.loads(line) for line in open(f'models/{model}/pred.jsonl', encoding='utf-8')]
x, y, scores = [], [], []
for pred in prediction:
    x.append(pred["length"])
    y.append(pred["response_length"])
    scores.append(score(pred["length"], pred["response_length"]))

print(np.mean(scores))

# set plt size 6x6
plt.figure(figsize=(6, 6))
lmt = 25000
# plot x, y
plt.scatter(x, y, s=100, c='r', alpha=0.3)
# plot x=y
plt.plot([0, lmt], [0, lmt], 'k--')
plt.xscale('log')
plt.yscale('log')
plt.xlim(50, lmt)
plt.ylim(50, lmt)
plt.xlabel('Required Length', fontsize=20, fontweight='bold')
plt.ylabel('Output Length', fontsize=20, fontweight='bold')
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)
plt.tight_layout()
plt.savefig(f'models/{model}/scatter.png')
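`score(x, y)` above compares the required length `x` with the output length `y`. Overshooting costs one third of the score for each full multiple of excess, and undershooting costs one half for each multiple of shortfall. The score therefore reaches 0 at four times the required length, or at one third of it, so falling short is penalized more heavily than running long. The worked example below reuses the same function with a required length of 1000. The output lengths are illustrative inputs, not benchmark data.

```python
def score(x: float, y: float) -> float:
    """Length score copied from eval_length.py: x = required length, y = output length."""
    if y > x:
        return 100 * max(0, 1. - (y / x - 1) / 3)
    else:
        return 100 * max(0, 1. - (x / y - 1) / 2)


required = 1000
for output in (1000, 2000, 4000, 500, 250):
    print(output, f"{score(required, output):.2f}")
# 1000 100.00  exact match
# 2000 66.67   twice as long: loses one third
# 4000 0.00    four times as long: score reaches 0
# 500 50.00    half as long: loses one half
# 250 0.00     a quarter as long: clipped at 0
```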

THUDM/LongWriter/blob/main/evaluation/eval_quality.py:

import json
import random
import requests
import multiprocessing
from tqdm import tqdm
import re

dims = ["Relevance", "Accuracy", "Coherence", "Clarity", "Breadth and Depth", "Reading Experience"]
model = "LongWriter-glm4-9b"
filename = f"models/{model}/judge.jsonl"
prediction_file = open(f"models/{model}/pred.jsonl", "r", encoding="utf-8")

prompt_template = open("judge.txt", "r", encoding="utf-8").read()
fout = open(filename, 'w', encoding='utf-8')

GPT4_API_KEY = '' # Your API Key
GPT_MODEL = 'gpt-4o-2024-05-13'
def get_response_gpt4(prompt, temperature=0.5, max_new_tokens=1024, stop=None):
    tries = 0
    while tries < 10:
        tries += 1
        try:
            headers = {
                'Authorization': "Bearer {}".format(GPT4_API_KEY),
            }
            messages = [
                {'role': 'user', 'content': prompt},
            ]
            resp = requests.post("https://api.openai.com/v1/chat/completions", json = {
                "model": GPT_MODEL,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_new_tokens,
                "stop": stop,
            }, headers=headers, timeout=600)
            if resp.status_code != 200:
                raise Exception(resp.text)
            resp = resp.json()
            break
        except KeyboardInterrupt as e:
            raise e
        except Exception as e:
            if "maximum context length" in str(e):
                raise e
            elif "triggering" in str(e):
                return 'Trigger OpenAI\'s content management policy'
            print("Error Occurs: \"%s\"        Retry ..."%(str(e)))
    else:
        print("Max tries. Failed.")
        return "Max tries. Failed."
    try:
        return resp["choices"][0]["message"]["content"]
    except: 
        return ''

def extract_info(pattern, text):
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1)
    else:
        return None

def process_data(items):
    for item in tqdm(items):
        prompt = prompt_template.replace('$INST$', item['prompt']).replace('$RESPONSE$', item["response"])
        scores = None
        trys = 0
        while scores is None and trys < 5:
            output = get_response_gpt4(prompt)
            try:
                if '```json' in output:
                    output = extract_info(r'```json\n(.*?)\n```', output)
                output = output.replace('\n', '')
                scores = json.loads(output)
                for dim in dims:
                    if dim not in scores:
                        scores = None
                        trys += 1
                        break
            except Exception as e:
                trys += 1
        if scores is None:
            print(output)
        else:
            item['scores'] = scores
            fout.write(json.dumps(item, ensure_ascii=False)+'\n')
            fout.flush()

data = [json.loads(line) for line in prediction_file]
random.shuffle(data)
PROC_NUM = 8
pool = multiprocessing.Pool(processes=PROC_NUM)
total = len(data)

for i in range(PROC_NUM):
    start = (i * total) // PROC_NUM
    end = None if i == PROC_NUM - 1 else ((i + 1) * total) // PROC_NUM
    pool.apply_async(process_data, args=(data[start:end],))

pool.close()
pool.join()
fout.close()

all_scores = [json.loads(line)['scores'] for line in open(filename, 'r', encoding='utf-8')]

total_score = dict()
for dim in dims:
    scores = [float(score[dim]) if dim in score else 3 for score in all_scores]
    total_score[dim] = ((sum(scores) / len(scores)) - 1) * 25
total_score['total'] = sum(total_score.values()) / len(total_score)
print(total_score)
with open(filename, 'a', encoding='utf-8') as fout:
    fout.write(json.dumps(total_score, ensure_ascii=False)+'\n')
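Each judge score is an integer from 1 to 5. `eval_quality.py` maps a dimension's mean `m` onto a 0 to 100 scale with `(m - 1) * 25`, and it reports the unweighted average of the six dimensions as `total`. A dimension missing from a judgment defaults to 3, which is worth 50 points. The sketch below reproduces that arithmetic on two made-up judgments; the scores are illustrative, not results from the paper.

```python
from typing import Dict, List

DIMS = ["Relevance", "Accuracy", "Coherence", "Clarity", "Breadth and Depth", "Reading Experience"]


def aggregate(all_scores: List[Dict[str, object]]) -> Dict[str, float]:
    total: Dict[str, float] = {}
    for dim in DIMS:
        # A dimension missing from a judgment counts as the midpoint score 3.
        vals = [float(s[dim]) if dim in s else 3 for s in all_scores]
        total[dim] = (sum(vals) / len(vals) - 1) * 25  # mean on 1..5 -> 0..100
    # Right-hand side is evaluated before "total" is added, so it averages the six dims.
    total["total"] = sum(total.values()) / len(total)
    return total


# Two illustrative judgments; the second omits "Accuracy".
judgments = [
    {"Relevance": 5, "Accuracy": 4, "Coherence": 4, "Clarity": 4,
     "Breadth and Depth": 3, "Reading Experience": 4},
    {"Relevance": 4, "Coherence": 5, "Clarity": 4,
     "Breadth and Depth": 3, "Reading Experience": 3},
]
result = aggregate(judgments)
print(result["Relevance"], result["Accuracy"], f"{result['total']:.2f}")  # 87.5 62.5 70.83
```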

THUDM/LongWriter/blob/main/evaluation/judge.txt:

You are an expert in evaluating text quality. Please evaluate the quality of an AI assistant's response to a user's writing request. Be as strict as possible.

You need to evaluate across the following six dimensions, with scores ranging from 1 to 5. The scoring criteria from 5 to 1 for each dimension are as follows:

1. Relevance: From content highly relevant and fully applicable to the user's request to completely irrelevant or inapplicable.

2. Accuracy: From content completely accurate with no factual errors or misleading information to content with numerous errors and highly misleading.

3. Coherence: From clear structure with smooth logical connections to disorganized structure with no coherence.

4. Clarity: From clear language, rich in detail, and easy to understand to confusing expression with minimal details.

5. Breadth and Depth: From both broad and deep content with a lot of information to seriously lacking breadth and depth with minimal information.

6. Reading Experience: From excellent reading experience, engaging and easy to understand content to very poor reading experience, boring and hard to understand content.

Please evaluate the quality of the following response to a user's request according to the above requirements.

<User Request>

$INST$

</User Request>

<Response>

$RESPONSE$

</Response>

Please evaluate the quality of the response. You must first provide a brief analysis of its quality, then give a comprehensive analysis with scores for each dimension. The output must strictly follow the JSON format: {"Analysis": ..., "Relevance": ..., "Accuracy": ..., "Coherence": ..., "Clarity": ..., "Breadth and Depth": ..., "Reading Experience": ...}. You do not need to consider whether the response meets the user's length requirements in your evaluation. Ensure that only one integer between 1 and 5 is output for each dimension score.
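The judge is asked for a bare JSON object, but models often wrap it in a Markdown code fence tagged `json`. `process_data` in `eval_quality.py` strips that fence with a regex before calling `json.loads`, and it retries when a dimension is missing. The sketch below is a standalone version of that parsing step. `parse_judgment` is a hypothetical name, and unlike the original it returns `None` instead of retrying the API call.

```python
import json
import re
from typing import Dict, Optional

DIMS = ["Relevance", "Accuracy", "Coherence", "Clarity", "Breadth and Depth", "Reading Experience"]
FENCE = "`" * 3  # built indirectly so this example can itself sit inside a Markdown fence
FENCED_JSON = re.compile(FENCE + r"json\n(.*?)\n" + FENCE, re.DOTALL)


def parse_judgment(output: str) -> Optional[Dict[str, object]]:
    """Return the judge's scores, or None if the output is unusable."""
    match = FENCED_JSON.search(output)
    if match:
        output = match.group(1)
    try:
        # Raw newlines are dropped, as in eval_quality.py, before parsing.
        scores = json.loads(output.replace("\n", ""))
    except json.JSONDecodeError:
        return None
    if not isinstance(scores, dict) or any(dim not in scores for dim in DIMS):
        return None
    return scores


raw = (FENCE + 'json\n{"Analysis": "Solid.", "Relevance": 5, "Accuracy": 4,\n'
       ' "Coherence": 4, "Clarity": 4, "Breadth and Depth": 3, "Reading Experience": 4}\n' + FENCE)
print(parse_judgment(raw)["Breadth and Depth"])  # 3
print(parse_judgment('{"Relevance": 5}'))        # None: other dimensions missing
```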

THUDM/LongWriter/blob/main/evaluation/pred.py:

import requests
import time, os, json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
import random
import codecs
import argparse
from copy import deepcopy
from tqdm import tqdm
import traceback
import re
import torch.distributed as dist
import torch.multiprocessing as mp

def count_words(text):
    chinese_characters = re.findall(r'[\u4e00-\u9fff]', text)
    english_words = re.findall(r'\b[a-zA-Z]+\b', text)
    
    chinese_char_count = len(chinese_characters)
    english_word_count = len(english_words)
    
    total_count = chinese_char_count + english_word_count
    
    return total_count

def get_pred(rank, world_size, data, path, max_new_tokens, temperature, tokenizer, fout):
    device = torch.device(f'cuda:{rank}')
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
    model = model.eval()
    for dt in data:
        prompt = dt['prompt']
        if "llama" in path.lower():
            prompt = f"[INST]{prompt}[/INST]"
            input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
            context_length = input.input_ids.shape[-1]
            output = model.generate(
                **input,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                do_sample=True,
                temperature=temperature,
            )[0]
            response = tokenizer.decode(output[context_length:], skip_special_tokens=True)
        else:
            response, history = model.chat(tokenizer, prompt, history=[], max_new_tokens=max_new_tokens, temperature=temperature)
        dt["response_length"] = count_words(response)
        dt["response"] = response
        fout.write(json.dumps(dt, ensure_ascii=False)+'\n')
        fout.flush()
        print(response)

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)

if __name__ == '__main__':
    seed_everything(42)
    model = 'LongWriter-glm4-9b' # LongWriter-llama3.1-8b
    path = "THUDM/LongWriter-glm4-9b" # THUDM/LongWriter-llama3.1-8b
    os.makedirs(f"models/{model}", exist_ok=True)
    fout = open(f"models/{model}/pred.jsonl", 'w', encoding='utf-8')

    max_new_tokens = 32768
    temperature = 0.5
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    world_size = torch.cuda.device_count()

    with open('longbench_write.jsonl', encoding='utf-8') as f:
        data = [json.loads(line) for line in f]

    data_subsets = [data[i::world_size] for i in range(world_size)]
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=get_pred, args=(rank, world_size, data_subsets[rank], path, max_new_tokens, temperature, tokenizer, fout))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
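`count_words` measures length as CJK characters plus English words. This is how `response_length` is computed before it is compared with each prompt's required `length`. The quick check below uses a condensed copy of the same function to show how mixed text is counted. One consequence of Python's Unicode-aware `\b`, as we read it, is that an English word written directly against a Chinese character has no word boundary on that side. Such a word is not counted.

```python
import re


def count_words(text: str) -> int:
    """Condensed copy of count_words from pred.py: CJK characters plus ASCII-letter words."""
    chinese_characters = re.findall(r'[\u4e00-\u9fff]', text)
    english_words = re.findall(r'\b[a-zA-Z]+\b', text)
    return len(chinese_characters) + len(english_words)


print(count_words("LongWriter can write 10,000 words"))  # 4: digits are not counted
print(count_words("中国旅游指南"))                        # 6: one per character
print(count_words("Hello 世界"))                          # 3
print(count_words("用LongWriter写作"))                    # 3: "LongWriter" touches CJK characters
```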

THUDM/LongWriter/blob/main/requirements.txt:

tqdm
numpy
torch
transformers==4.33.0
datasets
einops

THUDM/LongWriter/blob/main/train/README.md:

## 🖥️ LongWriter Training

### Data preprocessing

First, tokenize the raw text data with the model's tokenizer: run `pre_tokenize_glm4.py` for GLM-4-9B or `pre_tokenize_llama3.py` for Llama-3.1-8B. Remember to add the path to your general SFT data. Please format your data as follows:
```json
{
    "messages": [{"role": "user", "content": "..."},
                 {"role": "assistant", "content": "..."}, ...]
}
```

We use a packing strategy for more efficient training. Run

python sort_and_group.py --train_file ./data/glm4/longwriter

to organize the tokenized data for packing training.

### Model training

We provide training scripts under `scripts/` for the GLM-4-9B and Llama-3.1-8B model series. Make sure to adjust `--model_name_or_path`, `--train_file`, and `--output_dir` to match your model path, data path, and output path.

To support packing training, we provide patch files under `patch/`; please replace the original modeling files with them.
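Packing concatenates several tokenized samples into one fixed-length training row so that little of each row is wasted on padding. The collator then needs the boundaries of each sample within the row. `sort_and_group.py` is not shown here, so the following is only a hypothetical first-fit-decreasing packer that illustrates the idea. The function name `pack_first_fit`, the `max_len` value, and the greedy policy are assumptions, not the repository's algorithm. For each packed row, the function returns the sample indices and their cumulative boundaries, starting at 0.

```python
from typing import List, Tuple


def pack_first_fit(lengths: List[int], max_len: int) -> List[Tuple[List[int], List[int]]]:
    """Hypothetical packer: greedily place samples (given by token length) into rows.

    Returns one (sample_indices, boundaries) pair per row, where boundaries are
    cumulative offsets [0, l1, l1 + l2, ...] marking where each sample ends.
    """
    rows: List[List[int]] = []  # sample indices per row
    used: List[int] = []        # tokens used per row
    # Longest first, so large samples claim rows before small ones fill the gaps.
    for idx in sorted(range(len(lengths)), key=lambda i: -lengths[i]):
        if lengths[idx] > max_len:
            raise ValueError(f"sample {idx} is longer than max_len")
        for r, u in enumerate(used):
            if u + lengths[idx] <= max_len:
                rows[r].append(idx)
                used[r] += lengths[idx]
                break
        else:  # no existing row has room: open a new one
            rows.append([idx])
            used.append(lengths[idx])
    packed = []
    for row in rows:
        bounds = [0]
        for idx in row:
            bounds.append(bounds[-1] + lengths[idx])
        packed.append((row, bounds))
    return packed


for row, bounds in pack_first_fit([5, 3, 6, 2, 4], max_len=8):
    print(row, bounds)
# [2, 3] [0, 6, 8]
# [0, 1] [0, 5, 8]
# [4] [0, 4]
```

Here 20 tokens fit into three rows of 8 instead of five padded rows. Sorting longest-first is a common heuristic for reducing wasted slots; whether the repository uses the same ordering is not shown here.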


THUDM/LongWriter/blob/main/train/dataset.py:

```py
import torch
import os
import json
import numpy as np
from torch.utils.data import DataLoader

class LMDataset(torch.utils.data.Dataset):
    def __init__(self, filepath):
        self.input_ids, self.labels = self.process_data(filepath)
        self.input_ids = self.input_ids
        self.labels = self.labels

    def process_data(self, filepath):
        input_ids = torch.from_numpy(np.load(os.path.join(filepath, 'inputs.npy')))
        labels = torch.from_numpy(np.load(os.path.join(filepath, 'labels.npy')))
        return input_ids, labels

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'labels': self.labels[idx]
        }

    def __len__(self):
        return self.input_ids.size(0)

class LMSortDataset(torch.utils.data.Dataset):
    def __init__(self, filepath):
        self.input_ids, self.labels = self.process_data(filepath)
        self.input_ids = self.input_ids
        self.labels = self.labels
    
    def process_data(self, filepath):
        input_ids = torch.from_numpy(np.load(os.path.join(filepath, 'inputs_sort.npy')))
        labels = torch.from_numpy(np.load(os.path.join(filepath, 'labels_sort.npy')))
        return input_ids, labels

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'labels': self.labels[idx]
        }

    def __len__(self):
        return self.input_ids.size(0)

class LMPackDataset(torch.utils.data.Dataset):
    def __init__(self, filepath):
        self.input_ids, self.attention_masks, self.labels, self.weights, self.nums = self.process_data(filepath)
        self.num_gpus = torch.cuda.device_count()
        
    def process_data(self, filepath):
        input_ids = torch.from_numpy(np.load(os.path.join(filepath, 'inputs_pack.npy')))
        labels = torch.from_numpy(np.load(os.path.join(filepath, 'labels_pack.npy')))
        weights = torch.from_numpy(np.load(os.path.join(filepath, 'weights_pack.npy')))
        attention_masks = json.load(open(os.path.join(filepath, 'attention_masks_pack.json')))
        num_gpus = torch.cuda.device_count()
        l = (input_ids.size(0) // num_gpus) * num_gpus
        input_ids, labels, weights, attention_masks = input_ids[:l, :], labels[:l, :], weights[:l, :], attention_masks[:l]
        nums = [weights[i*num_gpus:(i+1)*num_gpus, :].sum() for i in range(l//num_gpus)]
        return input_ids, attention_masks, labels, weights, nums

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': torch.tensor(self.attention_masks[idx], dtype=torch.int32),
            'labels': (self.labels[idx], self.weights[idx], self.nums[idx//self.num_gpus])
        }

    def __len__(self):
        return self.input_ids.size(0)
```

THUDM/LongWriter/blob/main/train/main.py:

import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence, List
import os
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from dataset import LMDataset, LMSortDataset, LMPackDataset
from trainer import TrainerNoShuffle

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="THUDM/glm-4-9b")
    pack_loss: bool = field(default=False)

@dataclass
class DataArguments:
    train_file: str = field(default=None, metadata={"help": "Path to the training data."})
    validation_file: str = field(default=None, metadata={"help": "Path to the validation data."})
    preprocessing_num_workers: Optional[int] = field(
        default=1,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    prompt_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    response_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    batch_method: str = field(default="naive")

@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")

@dataclass
class DataCollatorForLMDataset(object):

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key].unsqueeze(0) for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.cat(input_ids, dim=0)
        labels = torch.cat(labels, dim=0)
        eos_indices = input_ids.argmin(dim=1) - 1
        max_position = eos_indices.max()
        if max_position < 0:
            return dict(
                input_ids=input_ids,
                labels=labels
            )
        return dict(
            input_ids=input_ids[:, :max_position+1],
            labels=labels[:, :max_position+1]
        )

@dataclass
class DataCollatorForLMPackDataset(object):

    def __call__(self, instances):
        input_ids, attention_masks = tuple([instance[key].unsqueeze(0) for instance in instances] for key in ["input_ids", "attention_mask"])
        batch_seq_num = instances[0]["labels"][2]
        labels = ([instance["labels"][0].unsqueeze(0) for instance in instances], [instance["labels"][1].unsqueeze(0) for instance in instances])
        input_ids = torch.cat(input_ids, dim=0)
        labels = (torch.cat(labels[0], dim=0), torch.cat(labels[1], dim=0))
        labels = (labels[0], labels[1].sum()/30)
        max_length = input_ids.shape[1]
        attention_mask = attention_masks[0].squeeze()
        acc_length = max_length
        for new_attention_mask in attention_masks[1:]:
            new_attention_mask = new_attention_mask.squeeze()
            attention_mask = torch.cat([attention_mask, new_attention_mask[1:]+acc_length], dim=0)
            acc_length += max_length
        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

def make_supervised_data_module(data_args) -> Dict:
    print("loading data...")
    if data_args.batch_method == "naive":
        train_dataset = LMDataset(data_args.train_file)
        data_collator = DataCollatorForLMDataset()
    elif data_args.batch_method == "pack":
        train_dataset = LMPackDataset(data_args.train_file)
        data_collator = DataCollatorForLMPackDataset()
    elif data_args.batch_method == "sort":
        train_dataset = LMSortDataset(data_args.train_file)
        data_collator = DataCollatorForLMDataset()
    else:
        raise ValueError(f"Unknown batch_method: {data_args.batch_method}")
    print("finish loading data")
    return dict(train_dataset=train_dataset, data_collator=data_collator)

def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if "chatglm" in model_args.model_name_or_path.lower() or "longalign-6b" in model_args.model_name_or_path.lower():
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True, empty_init=False
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=True
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, 
                                          torch_dtype=torch.bfloat16, 
                                          trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path,
                                                  trust_remote_code=True)
    if model_args.pack_loss:
        model.pack_loss = True
    data_module = make_supervised_data_module(data_args=data_args)

    trainer = TrainerNoShuffle(
        model=model, 
        tokenizer=tokenizer, 
        args=training_args, 
        **data_module
    )

    trainer.train(resume_from_checkpoint=False)
    trainer.save_model()

if __name__ == "__main__":
    train()
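`DataCollatorForLMPackDataset` flattens a batch of packed rows into one long sequence. Each row's `attention_mask` appears to hold cumulative sample boundaries, in the `cu_seqlens` style expected by `flash_attn_varlen_func`, which the patched `modeling_chatglm.py` imports. Every row after the first is shifted by the total length of the rows before it, with its leading 0 dropped. The sketch below reproduces that merge with plain Python lists. Reading the mask as boundaries is our interpretation of the code, not documented behavior, and `merge_boundaries` is a hypothetical name.

```python
from typing import List


def merge_boundaries(row_bounds: List[List[int]], row_len: int) -> List[int]:
    """Concatenate per-row cumulative boundaries into batch-level offsets.

    Mirrors the loop in DataCollatorForLMPackDataset: keep the first row's
    boundaries, then append each later row's boundaries (minus its leading 0)
    shifted by the number of tokens in all preceding rows.
    """
    merged = list(row_bounds[0])
    offset = row_len
    for bounds in row_bounds[1:]:
        merged.extend(b + offset for b in bounds[1:])
        offset += row_len
    return merged


# Two rows of 8 tokens: row 0 holds samples of length 3 and 5, row 1 of 5 and 3.
print(merge_boundaries([[0, 3, 8], [0, 5, 8]], row_len=8))  # [0, 3, 8, 13, 16]
```

The merged list yields clean per-sample segments only if each row's last boundary equals the padded row length (`input_ids.shape[1]` in the collator). We assume the preprocessing guarantees this.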

THUDM/LongWriter/blob/main/train/patch/modeling_chatglm.py:

""" PyTorch ChatGLM model. """

import math
import copy
import warnings
import re
import sys

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput

from .configuration_chatglm import ChatGLMConfig
from einops import rearrange
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        # FlashAttention-2
        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
    except ImportError:
        flash_attn_unpadded_func = None

# flags required to enable jit fusion kernels

if sys.platform != 'darwin':
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
_CONFIG_FOR_DOC = "ChatGLM6BConfig"

CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "THUDM/chatglm2-6b",
    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
]

def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores
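

# Illustrative sketch (not part of upstream LongWriter/ChatGLM code): the rule that
# InvalidScoreLogitsProcessor applies, written over a plain list of logits. If any
# score is NaN or +/-inf, every score is reset to 0 and token id 5 gets a large
# logit (5e4), so the next token is effectively forced to id 5 instead of being
# drawn from an invalid distribution. The helper name is hypothetical, and unlike
# the processor (which edits `scores` in place) it returns a new list.
import math


def _sketch_sanitize_scores(scores, fallback_token_id=5, fallback_logit=5e4):
    """Return a sanitized copy of a 1-D list of logits."""
    if any(math.isnan(s) or math.isinf(s) for s in scores):
        fixed = [0.0] * len(scores)
        fixed[fallback_token_id] = fallback_logit
        return fixed
    return list(scores)  # valid scores pass through unchanged

# Usage: _sketch_sanitize_scores([0.1, float("nan"), 0.3, 0.0, 0.0, 0.0])
# returns [0.0, 0.0, 0.0, 0.0, 0.0, 50000.0]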

def split_tensor_along_last_dim(
        tensor: torch.Tensor,
        num_partitions: int,
        contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.

    Returns:
        A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = b^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$, with b = base * rope_ratio (base defaults to 10000)

        base = base * self.rope_ratio
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )
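

# Illustrative sketch (not part of upstream code): the rotary cache built by
# RotaryEmbedding.forward_impl, in plain Python. rope_ratio multiplies the base,
# which lowers theta_i for every i > 0 (theta_0 is always 1) and so stretches the
# rotation wavelengths; base scaling of this kind is commonly used for longer
# contexts. Entry [m][i] is (cos(m * theta_i), sin(m * theta_i)) for position m
# and rotation pair i. The helper name is hypothetical.
import math


def _sketch_rope_cache(seq_len, n_elem, base=10000.0, rope_ratio=1.0):
    base = base * rope_ratio
    # theta_i = base^(-2i / n_elem) for i = 0 .. n_elem/2 - 1
    theta = [base ** (-(2 * i) / n_elem) for i in range(n_elem // 2)]
    return [[(math.cos(m * t), math.sin(m * t)) for t in theta] for m in range(seq_len)]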


@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)
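

# Illustrative sketch (not part of upstream code): what apply_rotary_pos_emb does
# for one head vector at one position. The first rot_dim = 2 * len(cache_row)
# features are read as interleaved pairs (x[2i], x[2i+1]); each pair is rotated by
# its (cos, sin) entry and the remaining features pass through unchanged (x_pass).
# ChatGLMModel constructs RotaryEmbedding with rotary_dim // 2, so only the first
# half of each head's features is rotated. The helper name is hypothetical.
def _sketch_rotate_pairs(x, cache_row):
    rot_dim = 2 * len(cache_row)
    out = []
    for i, (cos, sin) in enumerate(cache_row):
        x0, x1 = x[2 * i], x[2 * i + 1]
        out += [x0 * cos - x1 * sin, x1 * cos + x0 * sin]
    return out + list(x[rot_dim:])  # untouched tail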


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)
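

# Illustrative sketch (not part of upstream code): RMSNorm on one vector in plain
# Python. Unlike LayerNorm it neither subtracts the mean nor adds a bias:
#   y_j = w_j * x_j / sqrt(mean(x^2) + eps)
# The helper name is hypothetical.
import math


def _sketch_rms_norm(x, weight, eps=1e-5):
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * scale for w, v in zip(weight, x)]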


class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.attention_dropout = config.attention_dropout

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        seqlen_q, batch_size = query_layer.shape[0], query_layer.shape[1]
        seqlen_k = key_layer.shape[0]
        query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> (b s) ...') for x in [query_layer, key_layer, value_layer]]
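        # Attention via FlashAttention's variable-length ("varlen") kernel, which takes
        # flattened [total_tokens, heads, head_dim] inputs plus cumulative sequence
        # lengths (cu_seqlens) marking where each sequence starts.
        if attention_mask is None or attention_mask.ndim != 1:
            # Every sequence in the batch has length seqlen_q: boundaries are 0, s, 2s, ...
            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                        device=query_layer.device)
        else:
            # A 1-D attention_mask is used directly as precomputed cu_seqlens
            # (e.g. several sequences packed into one row).
            assert seqlen_q == seqlen_k
            cu_seqlens_q = attention_mask
        if self.training:
            assert seqlen_k == seqlen_q
            is_causal = True
            cu_seqlens_k = cu_seqlens_q
            dropout_p = self.attention_dropout
        else:
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                        device=query_layer.device) if not is_causal else cu_seqlens_q
            # Disable dropout at inference without overwriting the configured rate,
            # so switching back to train() keeps the original dropout.
            dropout_p = 0.0
        context_layer = flash_attn_unpadded_func(
            query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            dropout_p,
            softmax_scale=1.0 / self.norm_factor, causal=is_causal
        )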
        context_layer = rearrange(context_layer, '(b s) ... -> s b ...', b=batch_size)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)
        return context_layer
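

# Illustrative sketch (not part of upstream code): the cumulative sequence lengths
# ("cu_seqlens") handed to flash_attn_unpadded_func. For a uniform batch,
# CoreAttention builds 0, s, 2s, ..., b*s; a 1-D attention_mask is instead used as
# cu_seqlens directly, so the kernel keeps attention inside each segment of the
# flattened token stream. Both helpers are hypothetical and only show how such
# boundaries relate to per-sequence lengths.
from itertools import accumulate


def _sketch_cu_seqlens(lengths):
    # Boundaries for variable-length sequences laid end to end
    return [0] + list(accumulate(lengths))


def _sketch_uniform_cu_seqlens(batch_size, seq_len):
    # Same values as torch.arange(0, (batch_size + 1) * seq_len, step=seq_len)
    return list(range(0, (batch_size + 1) * seq_len, seq_len))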


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        if self.multi_query_attention:
            num_attention_heads = self.num_multi_query_groups_per_partition
        else:
            num_attention_heads = self.num_attention_heads_per_partition
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                               (self.num_attention_heads_per_partition,
                                3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # KV cache: prepend cached keys/values along the sequence dim (dim 0) and return the updated cache
        if use_cache:
            if kv_cache is not None:
                cache_k, cache_v = kv_cache
                key_layer = torch.cat((cache_k, key_layer), dim=0)
                value_layer = torch.cat((cache_v, value_layer), dim=0)
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None
        
            
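        # Grouped-query attention: repeat each of the num_multi_query_groups_per_partition
        # K/V heads so that every query head has a matching key and value head.
        if self.multi_query_attention: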
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache
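

# Illustrative sketch (not part of upstream code): sizes used by SelfAttention when
# multi_query_attention is enabled. Queries keep num_attention_heads heads while
# keys and values have only multi_query_group_num groups; before attention each
# K/V group is repeated num_attention_heads // multi_query_group_num times (the
# unsqueeze/expand/view sequence), giving [g0, g0, ..., g1, g1, ...]. Helper names
# are hypothetical; the example numbers are assumed values for illustration.
def _sketch_qkv_hidden_size(num_heads, head_dim, num_groups=None):
    projection_size = num_heads * head_dim
    if num_groups is None:  # plain multi-head attention: Q, K, V of equal size
        return 3 * projection_size
    return projection_size + 2 * head_dim * num_groups


def _sketch_expand_kv_groups(kv_heads, num_heads):
    # kv_heads: one entry per K/V group; returns one entry per query head
    repeat = num_heads // len(kv_heads)
    return [h for h in kv_heads for _ in range(repeat)]

# Usage: 32 heads of size 128 with 2 K/V groups -> 4096 + 2 * 128 * 2 = 4608
# output features for query_key_value, versus 3 * 4096 = 12288 without grouping.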


def _config_to_kwargs(args):
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    return common_kwargs


class MLP(torch.nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 2 * ffn_hidden_size
    (the gate and value halves used by SwiGLU), apply the SwiGLU nonlinearity to
    get ffn_hidden_size features, and project the state back into h hidden dimension.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output
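

# Illustrative sketch (not part of upstream code): the SwiGLU activation used by
# MLP, on a plain list. dense_h_to_4h produces 2 * ffn_hidden_size features; the
# first half is the gate, passed through SiLU (x * sigmoid(x)), and it multiplies
# the second half element-wise, leaving ffn_hidden_size features for
# dense_4h_to_h. The helper name is hypothetical.
import math


def _sketch_swiglu(x):
    half = len(x) // 2
    gate, value = x[:half], x[half:]
    silu = [g / (1.0 + math.exp(-g)) for g in gate]  # g * sigmoid(g)
    return [s * v for s, v in zip(silu, value)]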


class GLMBlock(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm

        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                             dtype=config.torch_dtype)

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                      dtype=config.torch_dtype)

        # MLP
        self.mlp = MLP(config, device=device)

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output

        return output, kv_cache
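

# Illustrative sketch (not part of upstream code): the data flow of
# GLMBlock.forward with dropout omitted, using plain callables on lists. With
# apply_residual_connection_post_layernorm=False (pre-norm), each sub-layer sees a
# normalized input and its output is added to the un-normalized stream:
#   h = x + attn(norm1(x));  out = h + mlp(norm2(h))
# With the flag set, the normalized tensors are used as the residuals instead.
# The helper name and its callable arguments are hypothetical.
def _sketch_glm_block(x, attn, mlp, norm1, norm2, residual_post_layernorm=False):
    def add(a, b):
        return [u + v for u, v in zip(a, b)]

    ln1 = norm1(x)
    h = add(ln1 if residual_post_layernorm else x, attn(ln1))
    ln2 = norm2(h)
    return add(ln2 if residual_post_layernorm else h, mlp(ln2))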


class GLMTransformer(torch.nn.Module):
    """Transformer class."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                 dtype=config.torch_dtype)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        presents = () if use_cache else None
        if self.gradient_checkpointing and self.training:
            if use_cache:
                # logger.warning_once(
                #     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                # )
                use_cache = False

        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache,
                    use_reentrant=False
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
                presents = presents + (kv_cache,)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, presents, all_hidden_states, all_self_attentions


class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        batch_size, seq_length = input_ids.shape
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                        device=input_ids.device), full_attention_mask), dim=-1)
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def get_position_ids(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value
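

# Illustrative sketch (not part of upstream code; note that ChatGLMModel.forward in
# this patched file does not call get_masks): the boolean mask that
# ChatGLMPreTrainedModel.get_masks builds for one sequence, where True means
# "masked out". Query i may see all past_length cached positions plus current
# positions j <= i, and key positions whose padding value is 0 are hidden. Without
# a cache, rows of padded queries are made fully visible, presumably so that no
# softmax row is entirely masked. The helper name is hypothetical.
def _sketch_causal_padding_mask(seq_len, past_length=0, padding=None):
    mask = []
    for i in range(seq_len):
        row = [1.0] * past_length + [1.0 if j <= i else 0.0 for j in range(seq_len)]
        if padding is not None:
            row = [r * p for r, p in zip(row, padding)]  # hide padded keys
            if not past_length:
                row = [r - (padding[i] - 1) for r in row]  # padded query row -> all visible
        mask.append([r < 0.5 for r in row])
    return mask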


class Embedding(torch.nn.Module):
    """Language model embeddings."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()
        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings


class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope, 
                                              device=device, dtype=config.torch_dtype)
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # if full_attention_mask is None:
        #     if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
        #         full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
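

# Illustrative sketches (not part of upstream LongWriter/ChatGLM code) of the
# bookkeeping done by ChatGLMForConditionalGeneration, using plain lists in place
# of tensors; all helper names are hypothetical.
# 1) Decoding: _update_model_kwargs_for_generation appends a 1 to each
#    attention-mask row and appends (last position id + 1); after the first step,
#    prepare_inputs_for_generation feeds only the newest token and its position,
#    since the KV cache already covers the prefix.
# 2) Training loss: forward scores the logits at position t against the label at
#    t + 1 and skips labels equal to -100. Only the default (non pack_loss) path
#    is sketched: the mean over the counted tokens.
import math


def _sketch_next_step_kwargs(attention_mask, position_ids):
    attention_mask = [row + [1] for row in attention_mask]
    position_ids = [row + [row[-1] + 1] for row in position_ids]
    return attention_mask, position_ids


def _sketch_model_inputs(input_ids, position_ids, is_first_forward):
    if not is_first_forward:
        input_ids = [row[-1:] for row in input_ids]
        position_ids = [row[-1:] for row in position_ids]
    return input_ids, position_ids


def _sketch_shifted_ce(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy for one sequence (logits: one row per position)."""
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == ignore_index:
            continue
        row = logits[t]
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        total += log_z - row[target]
        count += 1
    return total / count if count else float("nan")  # no counted tokens: 0 / 0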


class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.pack_loss = False

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            is_first_forward: bool = True,
            **kwargs
    ) -> dict:
        # After the first forward pass, feed only the newest token and its position id:
        # the KV cache (past_key_values) already covers the earlier tokens.
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            position_ids = position_ids[..., -1:]
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True
        }

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[Tuple[torch.Tensor]] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            if isinstance(labels, tuple) or isinstance(labels, list):
                labels, weights = labels
            shift_labels = labels[..., 1:].contiguous()
            if self.pack_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)#, reduction='none')
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                loss *= weights
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, response):
        response = response.strip()
        # Replace the "[[训练时间]]" ("[[training time]]") placeholder with "2023年" ("2023").
        response = response.replace("[[训练时间]]", "2023年")
        return response

    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(self.device)
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        response = self.process_response(response)
        return response, history

    def ppl(self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[Tuple[torch.Tensor]] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        return loss
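

# Illustrative sketch (hypothetical helpers, not part of the LongWriter patch).
# `forward` and `ppl` above score text with the usual causal shift: logits at
# position t are compared with the label at position t + 1, and labels equal to
# -100 are ignored. `ppl` uses reduction='none', so it returns a flat per-token
# vector in which ignored positions come back as 0; a perplexity should then be
# averaged only over the kept positions. The pure-Python helpers below mirror
# that arithmetic on small lists (None marks an ignored position) so it can be
# checked by hand; they are an assumption-labelled illustration, not the
# repository's evaluation code.
import math
from typing import List, Optional

IGNORE_INDEX = -100  # same sentinel as CrossEntropyLoss(ignore_index=-100)


def toy_shifted_token_losses(
    logits: List[List[float]], labels: List[int]
) -> List[Optional[float]]:
    """Per-token cross-entropy after the causal shift; None where ignored."""
    losses: List[Optional[float]] = []
    # logits[:-1] predict labels[1:] ("tokens < n predict n")
    for row, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            losses.append(None)
            continue
        # -log softmax(row)[target], computed stably with the log-sum-exp trick
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_z - row[target])
    return losses


def toy_perplexity(losses: List[Optional[float]]) -> float:
    """exp of the mean loss over the non-ignored positions."""
    kept = [x for x in losses if x is not None]
    return math.exp(sum(kept) / len(kept))


# Usage: uniform logits over a 4-token vocabulary cost log(4) per scored token,
# so the perplexity is 4 (up to float rounding):
#     losses = toy_shifted_token_losses([[0.0] * 4] * 3, [IGNORE_INDEX, 2, 1])
#     toy_perplexity(losses)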

THUDM/LongWriter/blob/main/train/patch/modeling_llama.py:

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
from einops import rearrange
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        # FlashAttention-2
        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
    except ImportError:
        flash_attn_unpadded_func = None
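

# Illustrative sketch (hypothetical helpers, not used by this file).
# The varlen ("unpadded") FlashAttention kernel imported above works on
# sequences packed back to back along a single token dimension. Its arguments
# include cumulative offsets (cu_seqlens_q / cu_seqlens_k, int32, length
# batch + 1) and the longest sequence length (max_seqlen_q / max_seqlen_k), so
# each token only attends within its own segment [cu[i], cu[i + 1]). How this
# patch builds those offsets is not shown in this section; the pure-Python
# helpers below only show the bookkeeping, whereas real code would pass torch
# int32 tensors on the GPU.
from itertools import accumulate
from typing import List, Tuple


def toy_cu_seqlens(seq_lens: List[int]) -> Tuple[List[int], int]:
    """Return (cu_seqlens, max_seqlen) for sequences packed back to back."""
    # Prefix sums with a leading 0: sequence i occupies [cu[i], cu[i + 1]).
    cu_seqlens = [0] + list(accumulate(seq_lens))
    return cu_seqlens, max(seq_lens)


def toy_sequence_of_token(cu_seqlens: List[int], token_idx: int) -> int:
    """Index of the packed sequence that a flat token index belongs to."""
    for i in range(len(cu_seqlens) - 1):
        if cu_seqlens[i] <= token_idx < cu_seqlens[i + 1]:
            return i
    raise IndexError(token_idx)


# Usage: three sequences of lengths 3, 5 and 2 packed into 10 tokens:
#     toy_cu_seqlens([3, 5, 2])  # ([0, 3, 8, 10], 5)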

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
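

# Illustrative sketch (hypothetical helper, not part of the patch).
# LlamaRMSNorm rescales each hidden vector by the reciprocal of its root mean
# square, with no mean subtraction and no bias (unlike LayerNorm):
#     y_i = w_i * x_i / sqrt(mean_j(x_j ** 2) + eps)
# The real module does this in float32 and casts back to the input dtype
# before applying the weight. `toy_rms_norm` mirrors only the formula.
import math
from typing import List, Optional


def toy_rms_norm(
    x: List[float], weight: Optional[List[float]] = None, eps: float = 1e-6
) -> List[float]:
    variance = sum(v * v for v in x) / len(x)  # like hidden_states.pow(2).mean(-1)
    inv_rms = 1.0 / math.sqrt(variance + eps)  # like torch.rsqrt(variance + eps)
    if weight is None:
        weight = [1.0] * len(x)  # the module initialises its weight to ones
    return [w * v * inv_rms for w, v in zip(weight, x)]


# Usage: for [3.0, 4.0] the mean square is 12.5, so the output is roughly
# [0.8485, 1.1314], whose own mean square is about 1.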


class LlamaRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[LlamaConfig] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.45"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                # Fall back to the legacy "type" key without raising KeyError when "rope_type" is absent.
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
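

# Illustrative sketch (hypothetical helpers, not part of the patch).
# For the "default" rope_type, the inverse frequencies follow the standard
# RoPE schedule inv_freq[i] = base ** (-2i / dim) for i in [0, dim / 2), and
# forward() above builds angles = position * inv_freq, duplicated along the
# last dimension (torch.cat((freqs, freqs))) before taking cos/sin. The
# helpers below mirror that for a single position in pure Python. Scaled
# variants change inv_freq and/or attention_scaling, which is 1.0 for the
# default type.
import math
from typing import List, Tuple


def toy_default_inv_freq(dim: int, base: float = 10000.0) -> List[float]:
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


def toy_rope_cos_sin(position: int, inv_freq: List[float]) -> Tuple[List[float], List[float]]:
    freqs = [position * f for f in inv_freq]
    emb = freqs + freqs  # mirrors torch.cat((freqs, freqs), dim=-1)
    return [math.cos(a) for a in emb], [math.sin(a) for a in emb]


# Usage: toy_default_inv_freq(4) is [1.0, 0.01]; at position 0 every cos is 1
# and every sin is 0, so position 0 is left unrotated.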


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "`LlamaLinearScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use "
            "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
        )
        kwargs["rope_type"] = "linear"
        super().__init__(*args, **kwargs)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use "
            "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
            "__init__)."
        )
        kwargs["rope_type"] = "dynamic"
        super().__init__(*args, **kwargs)
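

# Illustrative sketch (hypothetical helpers, not part of the patch).
# The two deprecated subclasses above only set rope_type; the actual scaling
# comes from ROPE_INIT_FUNCTIONS. The formulas below are our reading of the
# "linear" and "dynamic" (NTK) rules in recent transformers releases, written
# in pure Python:
#   linear:  inv_freq / factor, which is the same as dividing every position
#            index by `factor` (position interpolation);
#   dynamic: unchanged up to max_position_embeddings; beyond it the base grows
#            as base * (factor * seq_len / max_pos - (factor - 1)) ** (dim / (dim - 2)),
#            stretching the low frequencies to cover the longer sequence.
from typing import List


def toy_inv_freq(dim: int, base: float) -> List[float]:
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


def toy_linear_inv_freq(dim: int, base: float, factor: float) -> List[float]:
    return [f / factor for f in toy_inv_freq(dim, base)]


def toy_dynamic_ntk_inv_freq(
    dim: int, base: float, factor: float, seq_len: int, max_position_embeddings: int
) -> List[float]:
    # Sequences no longer than the trained length keep the original schedule.
    seq_len = max(seq_len, max_position_embeddings)
    scale = (factor * seq_len / max_position_embeddings) - (factor - 1)
    new_base = base * scale ** (dim / (dim - 2))
    return toy_inv_freq(dim, new_base)


# Usage: at seq_len == max_position_embeddings the dynamic schedule equals the
# default one; at twice that length every frequency except the first shrinks.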


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcast to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
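

# Illustrative sketch (hypothetical helpers, not part of the patch).
# rotate_half pairs dimension j with j + dim/2, and since cos/sin are built
# from torch.cat((freqs, freqs)) both members of a pair share one angle.
# apply_rotary_pos_emb therefore rotates each (x_j, x_{j + dim/2}) plane by
# position * inv_freq[j]. Two consequences, checked below in pure Python:
# vector norms are preserved, and the score <rope(q, m), rope(k, n)> depends
# only on the offset m - n.
import math
from typing import List


def toy_rotate_half(x: List[float]) -> List[float]:
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]  # mirrors torch.cat((-x2, x1))


def toy_apply_rope(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    dim = len(x)
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    angles = [position * f for f in inv_freq] * 2  # duplicated like cat((freqs, freqs))
    rotated = toy_rotate_half(x)
    return [v * math.cos(a) + r * math.sin(a) for v, r, a in zip(x, rotated, angles)]


def toy_dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


# Usage: shifting both positions by the same amount leaves the score unchanged:
#     q, k = [0.3, -1.2, 0.5, 2.0], [1.1, 0.4, -0.7, 0.9]
#     toy_dot(toy_apply_rope(q, 5), toy_apply_rope(k, 2))    # equals the next line
#     toy_dot(toy_apply_rope(q, 13), toy_apply_rope(k, 10))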


class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj
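

# Illustrative sketch (hypothetical helpers, not part of the patch).
# LlamaMLP is a gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)). Assuming
# hidden_act == "silu" (the usual Llama setting) this is SwiGLU. The
# pretraining_tp > 1 branch splits the intermediate dimension into tp slices
# (gate/up weights along dim=0, down weights along dim=1) and sums the partial
# down projections; because the gate is elementwise, that sum equals the
# unsplit result. Weights are row-major (out_features, in_features) lists,
# like nn.Linear.weight, and biases are omitted.
import math
from typing import List

Matrix = List[List[float]]


def toy_linear(x: List[float], w: Matrix) -> List[float]:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def toy_silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def toy_swiglu_mlp(x: List[float], gate_w: Matrix, up_w: Matrix, down_w: Matrix) -> List[float]:
    hidden = [toy_silu(g) * u for g, u in zip(toy_linear(x, gate_w), toy_linear(x, up_w))]
    return toy_linear(hidden, down_w)


def toy_swiglu_mlp_tp(
    x: List[float], gate_w: Matrix, up_w: Matrix, down_w: Matrix, tp: int
) -> List[float]:
    if len(gate_w) % tp != 0:
        raise ValueError("intermediate size must be divisible by tp")
    s = len(gate_w) // tp  # slice width of the intermediate dimension
    out = [0.0] * len(down_w)
    for i in range(tp):
        g = toy_linear(x, gate_w[i * s:(i + 1) * s])  # rows of gate_proj
        u = toy_linear(x, up_w[i * s:(i + 1) * s])  # rows of up_proj
        h = [toy_silu(a) * b for a, b in zip(g, u)]
        d_slice = [row[i * s:(i + 1) * s] for row in down_w]  # columns of down_proj
        out = [o + p for o, p in zip(out, toy_linear(h, d_slice))]
    return out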


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
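

# Illustrative sketch (hypothetical helpers, not part of the patch).
# With grouped-query attention, num_key_value_groups = num_heads //
# num_key_value_heads. repeat_kv inserts a new axis right after the KV-head
# axis, expands it n_rep times and flattens, so each KV head appears n_rep
# times consecutively: query head h reads KV head h // n_rep. The helpers
# below show only that ordering, with strings standing in for head tensors.
from typing import List, TypeVar

T = TypeVar("T")


def toy_repeat_kv(kv_heads: List[T], n_rep: int) -> List[T]:
    return [head for head in kv_heads for _ in range(n_rep)]


def toy_kv_head_for_query_head(q_head: int, n_rep: int) -> int:
    return q_head // n_rep


# Usage: toy_repeat_kv(["k0", "k1"], 3) gives ["k0", "k0", "k0", "k1", "k1", "k1"].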


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)

        # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
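

# Illustrative sketch (hypothetical helpers, not part of the patch).
# The eager LlamaAttention path computes softmax(q k^T / sqrt(head_dim) + mask) v,
# where the mask is additive: 0 where attention is allowed and a very negative
# value elsewhere. The model builds that mask outside this section; here it is
# a square causal mask using -inf. Softmax is taken per query row after
# subtracting the row maximum for numerical stability.
import math
from typing import List, Optional

Vector = List[float]


def toy_square_causal_mask(n: int) -> List[Vector]:
    # 0.0 where query i may attend key j (j <= i), -inf elsewhere
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]


def toy_softmax(row: Vector) -> Vector:
    m = max(row)
    exps = [math.exp(v - m) for v in row]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def toy_eager_attention(
    q: List[Vector], k: List[Vector], v: List[Vector], mask: Optional[List[Vector]] = None
) -> List[Vector]:
    head_dim = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(head_dim) for kj in k]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[i])]
        weights = toy_softmax(scores)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


# Usage: with a causal mask the first query can only see the first key, so its
# output is exactly the first value vector.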


class LlamaFlashAttention2(LlamaAttention):
    """
    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
    untouched. The only required change is in the forward pass, which needs to call the public API of flash
    attention correctly and handle padding tokens if the input contains any.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; "
                "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
            )

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to (batch_size, num_heads, seq_length, head_dim) so RoPE and the KV cache can be applied;
        # the states are transposed below to the (batch_size, seq_length, num_heads, head_dim) layout
        # that flash attention expects.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view calls.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, the layer norms are usually cast to float32 for training stability,
        # so the input hidden states get silently upcast to float32. Hence, we need to
        # cast them back to the correct dtype to make sure everything works as expected.
        # This may slow down training and inference, so it is recommended not to cast
        # the LayerNorms to fp32. (LlamaRMSNorm handles it correctly.)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seem to have been silently cast to float32; this might be related to"
                " the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back"
                f" to {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
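

# Illustrative sketch (hypothetical helper, not part of the patch).
# The `_flash_attn_uses_top_left_mask` flag above exists because causal masks
# can be aligned two ways when the query length differs from the key length
# (e.g. during cached decoding). Bottom-right alignment treats the queries as
# the last q_len positions, so query i may attend key j when
# j <= i + (k_len - q_len); top-left alignment uses j <= i. The two agree only
# when q_len == k_len. This pure-Python helper just tabulates both rules.
from typing import List


def toy_causal_allowed(q_len: int, k_len: int, bottom_right: bool = True) -> List[List[bool]]:
    offset = k_len - q_len if bottom_right else 0
    return [[j <= i + offset for j in range(k_len)] for i in range(q_len)]


# Usage: two new queries against four cached keys.
#     toy_causal_allowed(2, 4)                      # [[T, T, T, F], [T, T, T, T]]
#     toy_causal_allowed(2, 4, bottom_right=False)  # [[T, F, F, F], [T, T, F, F]]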


class LlamaSdpaAttention(LlamaAttention):
    """
    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamaAttention` as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
    to the SDPA API.
    """

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
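

# Illustrative sketch, not called anywhere in this file: one reason LlamaSdpaAttention sets
# `is_causal=True` only when there is no explicit mask and q_len > 1. The reference
# implementation in PyTorch's scaled_dot_product_attention docs builds the `is_causal=True`
# mask top-left aligned, as torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) for L queries
# and S keys. For one new query during cached decoding (L == 1, S > 1), that mask would hide
# every key except the first, so no mask is applied in that case.
# `_example_top_left_causal_mask` is a hypothetical helper; True means "may attend".
def _example_top_left_causal_mask(num_queries, num_keys):
    """Boolean mask with entry [q][k] True when key k is visible to query q (k <= q)."""
    return [[k <= q for k in range(num_keys)] for q in range(num_queries)]


# Example: _example_top_left_causal_mask(1, 3) -> [[True, False, False]]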


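class LlamaLongAttention(LlamaAttention):
    """
    LongWriter attention module. It reuses the `LlamaAttention` weights but computes attention with the
    variable-length FlashAttention kernel (`flash_attn_unpadded_func`), so that sequences packed into one row
    attend only within their own boundaries.
    """
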
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
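        """Input shape: Batch x Time x Channel

        attention_mask: either `None`, or a 1-D int32 tensor of cumulative sequence lengths (`cu_seqlens`) marking
        the boundaries of sequences packed along the time axis. Any other mask (e.g. a 2-D padding mask) is ignored,
        and every row of the batch is treated as one full-length sequence.
        """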
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

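        batch_size, seqlen_q = query_states.shape[0], query_states.shape[2]
        seqlen_k = key_states.shape[2]
        # Flatten batch and time so that every token lives on one axis: (b, h, s, d) -> (b * s, h, d)
        query_states, key_states, value_states = [
            rearrange(x, "b h s d -> (b s) h d") for x in (query_states, key_states, value_states)
        ]

        # Cumulative sequence-length offsets for the variable-length FlashAttention kernel.
        # Without a 1-D `cu_seqlens` mask, each batch row is treated as one sequence of length seqlen_q.
        if attention_mask is None or attention_mask.ndim != 1:
            cu_seqlens_q = torch.arange(
                0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=query_states.device
            )
        else:
            # Packed sequences: `attention_mask` already holds the offsets
            assert seqlen_q == seqlen_k
            cu_seqlens_q = attention_mask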

        if self.training:
            assert seqlen_k == seqlen_q
            is_causal = True
            cu_seqlens_k = cu_seqlens_q
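        else:
            # With a KV cache the new queries are shorter than the keys (e.g. one token per decoding step), so no
            # causal mask is applied and each row's queries see all of that row's keys.
            is_causal = seqlen_q == seqlen_k
            if is_causal:
                cu_seqlens_k = cu_seqlens_q
            else:
                cu_seqlens_k = torch.arange(
                    0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=query_states.device
                )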
        output = flash_attn_unpadded_func(
            query_states, key_states, value_states, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, causal=is_causal
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        attn_output = self.o_proj(rearrange(output, 'b s h d -> b s (h d)'))

        return attn_output, None, past_key_value
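

# Illustrative sketch, not called anywhere in this file: the cumulative sequence-length offsets
# ("cu_seqlens") that LlamaLongAttention passes to flash_attn_unpadded_func. After
# `rearrange(x, "b h s d -> (b s) h d")` all tokens sit on one flat axis, and
# cu_seqlens[i]:cu_seqlens[i + 1] selects the tokens of sequence i. Plain lists stand in for the
# int32 tensors used above. All `_example_*` names are hypothetical. The assumption that packed
# offsets are built as in `_example_packed_cu_seqlens` is not confirmed by this file.
def _example_uniform_cu_seqlens(batch_size, seqlen):
    """Offsets for equal-length rows, i.e. torch.arange(0, (batch_size + 1) * seqlen, step=seqlen)."""
    return list(range(0, (batch_size + 1) * seqlen, seqlen))


def _example_packed_cu_seqlens(lengths):
    """Offsets for variable-length sequences packed back-to-back: [0, l0, l0 + l1, ...]."""
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    return offsets


def _example_kv_layout(training, batch_size, seqlen_q, seqlen_k, cu_seqlens_q):
    """Mirror of the branch above that chooses the causal flag and the key offsets."""
    if training:
        assert seqlen_q == seqlen_k
        return True, cu_seqlens_q
    if seqlen_q == seqlen_k:
        return True, cu_seqlens_q
    # Cached decoding: keys include past tokens, so use one uniform key span per row
    return False, _example_uniform_cu_seqlens(batch_size, seqlen_k)


# Example: _example_packed_cu_seqlens([3, 5, 2]) -> [0, 3, 8, 10]
# Example: _example_kv_layout(False, 2, 1, 7, [0, 1, 2]) -> (False, [0, 7, 14])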


LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
    "longwriter": LlamaLongAttention,
}
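

# Illustrative sketch, not called anywhere in this file: which key/value head each query head
# reads after `repeat_kv(states, self.num_key_value_groups)` in the attention classes above.
# This assumes `repeat_kv` keeps the upstream Transformers definition, in which each KV head is
# copied `n_rep` times in place. Under that assumption, expanded head j is a copy of KV head
# j // n_rep. `_example_gqa_head_map` is a hypothetical helper for illustration only.
def _example_gqa_head_map(num_heads, num_key_value_heads):
    """Return, for every query head, the index of the shared key/value head it uses."""
    if num_heads % num_key_value_heads != 0:
        raise ValueError("num_heads must be a multiple of num_key_value_heads")
    n_rep = num_heads // num_key_value_heads  # corresponds to self.num_key_value_groups
    return [q_head // n_rep for q_head in range(num_heads)]


# Example: 8 query heads sharing 2 KV heads -> [0, 0, 0, 0, 1, 1, 1, 1]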


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

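        # LongWriter always uses its packed-sequence FlashAttention path, overriding any implementation
        # requested via `attn_implementation` when the model was loaded.
        config._attn_implementation = "longwriter"
        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)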

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.Tensor`, *optional*):
                In this LongWriter variant, either `None` or a 1-D int32 tensor of cumulative sequence lengths
                (`cu_seqlens`) for sequences packed along the time axis; see `LlamaLongAttention.forward`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
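

# Illustrative sketch, not called anywhere in this file: the pre-norm residual pattern of
# LlamaDecoderLayer.forward applied to one token vector, with plain lists instead of tensors.
# RMSNorm is written as weight * x / sqrt(mean(x ** 2) + eps). This assumes LlamaRMSNorm keeps
# its upstream computation; its float32 upcast is omitted here. `attn` and `mlp` are hypothetical
# callables standing in for self.self_attn and self.mlp, and the helper names are illustrative.
def _example_rms_norm(x, weight, eps=1e-6):
    """Scale x to unit root-mean-square, then apply the learned per-channel weight."""
    inv_rms = (sum(v * v for v in x) / len(x) + eps) ** -0.5
    return [w * v * inv_rms for w, v in zip(weight, x)]


def _example_prenorm_layer(x, attn, mlp, ln1_weight, ln2_weight, eps=1e-6):
    # Attention sub-block: normalize, transform, then add the un-normalized input back
    h = attn(_example_rms_norm(x, ln1_weight, eps))
    x = [a + b for a, b in zip(x, h)]
    # MLP sub-block with its own RMSNorm and residual connection
    h = mlp(_example_rms_norm(x, ln2_weight, eps))
    return [a + b for a, b in zip(x, h)]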


LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
    etc.).

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Note that in this LongWriter variant `LlamaModel._update_causal_mask` discards any mask that is not 1-D.
            Only a 1-D int32 tensor of cumulative sequence lengths (`cu_seqlens`) for sequences packed along the time
            axis reaches the attention layers; otherwise every batch row is treated as one unpadded sequence.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (keys and values in the self-attention blocks) that can be used to speed up
            sequential decoding. This typically consists of the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`].

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
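        # LongWriter override: no 4-D causal mask is built. A 1-D `cu_seqlens` tensor describing packed sequences
        # is passed through unchanged; any other mask is dropped, and LlamaLongAttention relies on the
        # FlashAttention kernel's `causal` flag instead.
        if attention_mask is None or attention_mask.ndim != 1:
            return None
        return attention_mask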

    # def _update_causal_mask(
    #     self,
    #     attention_mask: torch.Tensor,
    #     input_tensor: torch.Tensor,
    #     cache_position: torch.Tensor,
    #     past_key_values: Cache,
    #     output_attentions: bool,
    # ):
    #     # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    #     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    #     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    #     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

    #     if self.config._attn_implementation == "flash_attention_2":
    #         if attention_mask is not None and 0.0 in attention_mask:
    #             return attention_mask
    #         return None

    #     # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    #     # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    #     # to infer the attention mask.
    #     past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    #     using_static_cache = isinstance(past_key_values, StaticCache)

    #     # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    #     if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
    #         if AttentionMaskConverter._ignore_causal_mask_sdpa(
    #             attention_mask,
    #             inputs_embeds=input_tensor,
    #             past_key_values_length=past_seen_tokens,
    #             is_training=self.training,
    #         ):
    #             return None

    #     dtype, device = input_tensor.dtype, input_tensor.device
    #     min_dtype = torch.finfo(dtype).min
    #     sequence_length = input_tensor.shape[1]
    #     if using_static_cache:
    #         target_length = past_key_values.get_max_length()
    #     else:
    #         target_length = (
    #             attention_mask.shape[-1]
    #             if isinstance(attention_mask, torch.Tensor)
    #             else past_seen_tokens + sequence_length + 1
    #         )

    #     if attention_mask is not None and attention_mask.dim() == 4:
    #         # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
    #         if attention_mask.max() != 0:
    #             raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
    #         causal_mask = attention_mask
    #     else:
    #         causal_mask = torch.full(
    #             (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
    #         )
    #         if sequence_length != 1:
    #             causal_mask = torch.triu(causal_mask, diagonal=1)
    #         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    #         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
    #         if attention_mask is not None:
    #             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    #             mask_length = attention_mask.shape[-1]
    #             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
    #             padding_mask = padding_mask == 0
    #             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    #                 padding_mask, min_dtype
    #             )
    #     if (
    #         self.config._attn_implementation == "sdpa"
    #         and attention_mask is not None
    #         and attention_mask.device.type == "cuda"
    #         and not output_attentions
    #     ):
    #         # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
    #         # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
    #         # Details: https://github.com/pytorch/pytorch/issues/110213
    #         causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

    #     return causal_mask
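

# Illustrative sketch, not called anywhere in this file: how LlamaModel.forward fills in
# `cache_position` and `position_ids` when the caller does not supply them. Cached tokens are
# counted with past_key_values.get_seq_length(), and the new tokens continue from there. Plain
# lists stand in for tensors, and `_example_default_positions` is a hypothetical helper. If
# several sequences are packed into one row, these defaults keep counting across sequence
# boundaries. Whether LongWriter's training code passes its own position_ids is not shown here.
def _example_default_positions(past_seen_tokens, num_new_tokens):
    # torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)
    cache_position = list(range(past_seen_tokens, past_seen_tokens + num_new_tokens))
    # cache_position.unsqueeze(0): a single row that broadcasts over the batch
    position_ids = [cache_position]
    return cache_position, position_ids


# Example: _example_default_positions(5, 2) -> ([5, 6], [[5, 6]])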


class LlamaForCausalLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pack_loss = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (last_hidden_state, past_key_values, hidden_states, attentions)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

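        loss = None
        if labels is not None:
            # Packed training data passes `labels` as a (labels, weights) pair
            weights = None
            if isinstance(labels, (tuple, list)):
                labels, weights = labels
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            if self.pack_loss:
                if weights is None:
                    raise ValueError("`pack_loss=True` requires `labels` to be a (labels, weights) pair")
                # Per-token weights are shifted exactly like the labels they weight; keep the
                # unreduced per-token loss so the weights can be applied before summing.
                shift_weights = weights[..., 1:].contiguous().view(-1).to(shift_logits.device)
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
                loss = (loss_fct(shift_logits, shift_labels) * shift_weights).sum()
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(shift_logits, shift_labels)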

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def ppl(self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
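            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        """Return the unreduced per-token cross-entropy for `labels`, flattened to shape
        `(batch_size * (sequence_length - 1),)`, or `None` when `labels` is not given.
        Positions labeled -100 get a loss of 0. `output_attentions` and `return_last_logit` are unused.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict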

        transformer_outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        lm_logits = lm_logits.float()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous()
            
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        return loss

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

@add_start_docstrings(
    """
    The Llama Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    base_model_prefix = "transformer"

    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
    def __init__(self, config):
        super().__init__(config)
        self.transformer = LlamaModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU, splitting may add an extra dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A cross-entropy loss is computed over every token position; tokens labeled
            `-100` are ignored (the `CrossEntropyLoss` default).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
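
In `LlamaForCausalLM.forward`, `labels` may arrive as a `(labels, weights)` pair when `pack_loss` is enabled, and `train/sort_and_group.py` builds those weights per packed sample as 1 for every supervised token, divided by the number of such tokens. The sketch below shows the weighting this setup appears to target, under the assumption that the weights are aligned with the labels and must be shifted with them: summing per-token losses times those weights gives the sum of each packed sample's mean token loss, so short and long samples count equally. The helper name `packed_weighted_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def packed_weighted_loss(logits, labels, weights):
    """Weighted causal-LM loss over a packed batch (illustrative sketch).

    logits:  (batch, seq, vocab) float tensor
    labels:  (batch, seq) long tensor, -100 where no loss is wanted
    weights: (batch, seq) float tensor aligned with `labels`
    """
    vocab = logits.size(-1)
    # Position t predicts token t + 1, so labels and weights are shifted together.
    shift_logits = logits[..., :-1, :].reshape(-1, vocab)
    shift_labels = labels[..., 1:].reshape(-1)
    shift_weights = weights[..., 1:].reshape(-1)
    # reduction="none" keeps one loss per token; ignored positions contribute 0.
    token_loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="none")
    return (token_loss * shift_weights).sum()


# Toy pack: two samples occupying positions 0-3 and 4-7; -100 marks prompt tokens.
torch.manual_seed(0)
logits = torch.randn(1, 8, 11)
labels = torch.tensor([[-100, 3, 5, 7, -100, -100, 2, 9]])
mask = (labels != -100).float()
weights = torch.zeros_like(mask)
for start, end in [(0, 4), (4, 8)]:
    # Per-sample normalization, mirroring `weight / weight.sum()` in sort_and_group.py.
    weights[0, start:end] = mask[0, start:end] / mask[0, start:end].sum()

loss = packed_weighted_loss(logits, labels, weights)

# Reference: the mean token loss of each sample, summed over samples.
ref = F.cross_entropy(logits[0, [0, 1, 2]], torch.tensor([3, 5, 7])) + F.cross_entropy(
    logits[0, [5, 6]], torch.tensor([2, 9])
)
```

With `reduction="mean"` instead, a single mean would be taken over all supervised tokens in the pack, so longer samples would dominate the loss.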

THUDM/LongWriter/blob/main/train/pre_tokenize_glm4.py:

from transformers import AutoTokenizer, AutoModel, LlamaTokenizer
import copy
import torch
import json, os, random
import multiprocessing
from tqdm import tqdm
import traceback
import numpy as np
import argparse

tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b", trust_remote_code=True)
max_length = 32768
PAD_ID = tokenizer.get_command("[MASK]")
EOS_ID = tokenizer.eos_token_id
skip_exceed_length_case = True
truncate_side = 'right'

PROC_NUM = 64
save_dir = 'multiprocess_data'

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default="chatglm", type=str)
    return parser.parse_args(args)

def process_file(lines, rank, args):
    def build_input(conversations, tokenizer, args):
        zero_width_chars = ["\u200b", "\u200c", "\u200d", "\ufeff"] # zero-width characters to strip from assistant replies
        for conv in conversations:
            if conv['role'] == "assistant":
                for char in zero_width_chars:
                    conv['content'] = conv['content'].replace(char, '')

        if len(conversations) == 0:
            return None

        input_ids = []
        starts = []
        ends = []
        for item in conversations:
            content = item["content"]
            role = item["role"]
            if role == 'assistant' and content != '':
                starts.append(len(input_ids))
            input_ids.extend(tokenizer.build_single_message(role, item.get("metadata", ""), content))
            if role == 'assistant' and content != '':
                ends.append(len(input_ids))
        input_ids.append(EOS_ID)
        input_ids = tokenizer.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
        inputs = input_ids.input_ids[0]
        labels = torch.full_like(inputs, -100)
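        # Supervise only the assistant spans. The +3 shift presumably skips the two prefix tokens the
        # GLM-4 tokenizer prepends ([gMASK], <sop>) plus the <|assistant|> role token, so each span also
        # covers the token right after the reply (the next role token or EOS), teaching the model to stop.
        for start, end in zip(starts, ends):
            labels[start+3:end+3] = inputs[start+3:end+3]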

        if inputs.shape[0] > max_length:
            print("exceed_length")
            if skip_exceed_length_case:
                return None
            if truncate_side == 'right':
                inputs = inputs[:max_length]
                labels = labels[:max_length]
            elif truncate_side == 'left':
                cut_num = inputs.shape[0] - max_length
                inputs = torch.cat([inputs[:2], inputs[2 + cut_num:]], dim=0)
                labels = torch.cat([labels[:2], labels[2 + cut_num:]], dim=0)
            else:
                raise ValueError('truncate_side must be "right" or "left"')
        return inputs, labels

    try:
        final_inputs = torch.full((len(lines), max_length), PAD_ID, dtype=torch.int64)
        final_labels = torch.full((len(lines), max_length), -100, dtype=torch.int64)
        pass_data_num = 0

        for line in tqdm(lines):
            conversations = json.loads(line)['messages']
            tmp = build_input(conversations, tokenizer, args)
            if tmp is None:
                continue
            inputs, labels = tmp
            final_inputs[pass_data_num, :inputs.shape[0]] = inputs
            final_labels[pass_data_num, :labels.shape[0]] = labels
            pass_data_num += 1
        final_inputs = final_inputs[:pass_data_num]
        final_labels = final_labels[:pass_data_num]
        torch.save(final_inputs, os.path.join(save_dir, f'inputs_{rank}.pt'))
        torch.save(final_labels, os.path.join(save_dir, f'labels_{rank}.pt'))
    except Exception:
        with open('error.txt', 'a') as f:
            traceback.print_exc(file=f)

def main(args):
    final_dir = f'data/glm4/longwriter'
    os.system('rm -r {}'.format(save_dir))
    os.makedirs(final_dir, exist_ok=True)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    manager = multiprocessing.Manager()
    lines = manager.list()

    general = open('general.jsonl', encoding='utf-8').readlines() # TODO: your general sft data
    longwriter = open('LongWriter-6k.jsonl', encoding='utf-8').readlines()
    lines = general + longwriter
    random.shuffle(lines)
    total_lines = len(lines)
    print(total_lines)

    pool = multiprocessing.Pool(processes=PROC_NUM)
    lines_per_process = total_lines // PROC_NUM

    for i in range(PROC_NUM):
        start_line = i * lines_per_process
        end_line = None if i == PROC_NUM - 1 else (i + 1) * lines_per_process
        pool.apply_async(process_file, args=(lines[start_line:end_line], i, args))

    pool.close()
    pool.join()

    inputs, labels = [], []
    for i in tqdm(range(PROC_NUM)):
        inputs.append(torch.load(os.path.join(save_dir, f'inputs_{i}.pt')))
        labels.append(torch.load(os.path.join(save_dir, f'labels_{i}.pt')))
    inputs = torch.cat(inputs, dim=0)
    labels = torch.cat(labels, dim=0)

    input_ids = inputs.numpy().astype(np.int64)
    labels = labels.numpy().astype(np.int64)
    filtered_rows = np.where(~np.all(labels == -100, axis=1))[0]
    input_ids = input_ids[filtered_rows]
    labels = labels[filtered_rows]

    print(labels.shape)
    np.save(os.path.join(final_dir, 'inputs.npy'), input_ids)
    np.save(os.path.join(final_dir, 'labels.npy'), labels)

if __name__ == '__main__':
    main(parse_args())
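
`build_input` in both pre-tokenization scripts records where each assistant reply starts and ends in token positions, copies only those spans into `labels`, and leaves everything else at -100 so that `CrossEntropyLoss(ignore_index=-100)` skips it; left truncation keeps the first two tokens and cuts from just after them (with the default `skip_exceed_length_case = True`, over-long samples are dropped before truncation is reached). A plain-list sketch of those two steps follows; `mask_labels`, `truncate_left`, and the `offset`/`keep_prefix` parameters are hypothetical names, with `offset` standing in for the `+3` shift used in the GLM-4 script.

```python
IGNORE_INDEX = -100  # label value skipped by CrossEntropyLoss(ignore_index=-100)


def mask_labels(input_ids, spans, offset=0):
    """Return labels that copy `input_ids` inside each [start, end) span and are -100 elsewhere."""
    labels = [IGNORE_INDEX] * len(input_ids)
    for start, end in spans:
        # Clamp to the sequence, as tensor slicing does in the original scripts.
        for i in range(start + offset, min(end + offset, len(input_ids))):
            labels[i] = input_ids[i]
    return labels


def truncate_left(input_ids, labels, max_length, keep_prefix=2):
    """Keep the first `keep_prefix` tokens and drop the oldest tokens after them until it fits."""
    cut = len(input_ids) - max_length
    if cut <= 0:
        return input_ids, labels
    return (
        input_ids[:keep_prefix] + input_ids[keep_prefix + cut:],
        labels[:keep_prefix] + labels[keep_prefix + cut:],
    )


ids = [1, 2, 10, 11, 12, 20, 21, 99]
labels = mask_labels(ids, spans=[(2, 5)])
short_ids, short_labels = truncate_left(ids, labels, max_length=6)
```

Keeping the prefix during left truncation preserves tokens that the model expects at the very start of every sequence, while the oldest conversation turns are the ones discarded.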

THUDM/LongWriter/blob/main/train/pre_tokenize_llama3.py:

from transformers import AutoTokenizer, AutoModel, LlamaTokenizer
import copy
import torch
import json, os, random
import multiprocessing
from tqdm import tqdm
import traceback
import numpy as np
import argparse

tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-llama3.1-8b", trust_remote_code=True, use_fast=False)
max_length = 32768
PAD_ID = 128004  # <|finetune_right_pad_id|> in the Llama 3.1 vocabulary
BOS_ID = tokenizer.bos_token_id
EOS_ID = tokenizer.eos_token_id
skip_exceed_length_case = True
truncate_side = 'right'

PROC_NUM = 64
save_dir = 'multiprocess_data'

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default="llama", type=str)
    return parser.parse_args(args)

def process_file(lines, rank, args):
    def build_input(conversations, tokenizer, args):
        zero_width_chars = ["\u200b", "\u200c", "\u200d", "\ufeff"] # zero-width characters to strip from assistant replies
        for conv in conversations:
            if conv['role'] == "assistant":
                for char in zero_width_chars:
                    conv['content'] = conv['content'].replace(char, '')

        if len(conversations) == 0:
            return None

        inputs = torch.full((1,), BOS_ID, dtype=torch.int64)
        starts = []
        ends = []
        for item in conversations:
            content = item["content"]
            role = item["role"]
            if role == 'system':
                cur_inputs = tokenizer(f"<<SYS>>\n{content}\n<</SYS>>\n\n", return_tensors="pt")['input_ids'][0]
            elif role == "user":
                cur_inputs = tokenizer(f"[INST]{content}[/INST]", return_tensors="pt")['input_ids'][0]
            else:
                starts.append(inputs.shape[0])
                cur_inputs = tokenizer(content, return_tensors="pt")['input_ids'][0]
                ends.append(inputs.shape[0] + cur_inputs.shape[0])
            inputs = torch.cat([inputs, cur_inputs], dim=0)
            
        inputs = torch.cat([inputs, torch.tensor([EOS_ID])], dim=0)
        labels = torch.full_like(inputs, -100)
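        for start, end in zip(starts, ends):
            # Supervise the assistant tokens, and set the label right after each reply to EOS so the last
            # reply token is trained to predict EOS (after the final reply this is the appended EOS).
            labels[start:end] = inputs[start:end]
            labels[end] = EOS_ID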

        if inputs.shape[0] > max_length:
            print("exceed_length")
            if skip_exceed_length_case:
                return None
            if truncate_side == 'right':
                inputs = inputs[:max_length]
                labels = labels[:max_length]
            elif truncate_side == 'left':
                cut_num = inputs.shape[0] - max_length
                inputs = torch.cat([inputs[:2], inputs[2 + cut_num:]], dim=0)
                labels = torch.cat([labels[:2], labels[2 + cut_num:]], dim=0)
            else:
                raise ValueError('truncate_side must be "right" or "left"')
        return inputs, labels

    try:
        final_inputs = torch.full((len(lines), max_length), PAD_ID, dtype=torch.int64)
        final_labels = torch.full((len(lines), max_length), -100, dtype=torch.int64)
        pass_data_num = 0

        for line in tqdm(lines):
            conversations = json.loads(line)['messages']
            tmp = build_input(conversations, tokenizer, args)
            if tmp is None:
                continue
            inputs, labels = tmp
            final_inputs[pass_data_num, :inputs.shape[0]] = inputs
            final_labels[pass_data_num, :labels.shape[0]] = labels
            pass_data_num += 1
        final_inputs = final_inputs[:pass_data_num]
        final_labels = final_labels[:pass_data_num]
        torch.save(final_inputs, os.path.join(save_dir, f'inputs_{rank}.pt'))
        torch.save(final_labels, os.path.join(save_dir, f'labels_{rank}.pt'))
    except Exception:
        with open('error.txt', 'a') as f:
            traceback.print_exc(file=f)

def main(args):
    final_dir = f'data/llama3/longwriter'
    os.system('rm -r {}'.format(save_dir))
    os.makedirs(final_dir, exist_ok=True)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    manager = multiprocessing.Manager()
    lines = manager.list()

    general = open('general.jsonl', encoding='utf-8').readlines() # TODO: your general sft data
    longwriter = open('LongWriter-6k.jsonl', encoding='utf-8').readlines()
    lines = general + longwriter
    random.shuffle(lines)
    total_lines = len(lines)
    print(total_lines)

    pool = multiprocessing.Pool(processes=PROC_NUM)
    lines_per_process = total_lines // PROC_NUM

    for i in range(PROC_NUM):
        start_line = i * lines_per_process
        end_line = None if i == PROC_NUM - 1 else (i + 1) * lines_per_process
        pool.apply_async(process_file, args=(lines[start_line:end_line], i, args))

    pool.close()
    pool.join()

    inputs, labels = [], []
    for i in tqdm(range(PROC_NUM)):
        inputs.append(torch.load(os.path.join(save_dir, f'inputs_{i}.pt')))
        labels.append(torch.load(os.path.join(save_dir, f'labels_{i}.pt')))
    inputs = torch.cat(inputs, dim=0)
    labels = torch.cat(labels, dim=0)

    input_ids = inputs.numpy().astype(np.int64)
    labels = labels.numpy().astype(np.int64)
    filtered_rows = np.where(~np.all(labels == -100, axis=1))[0]
    input_ids = input_ids[filtered_rows]
    labels = labels[filtered_rows]

    print(labels.shape)
    np.save(os.path.join(final_dir, 'inputs.npy'), input_ids)
    np.save(os.path.join(final_dir, 'labels.npy'), labels)

if __name__ == '__main__':
    main(parse_args())
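
The loop in `train/sort_and_group.py` (the last file in this excerpt) greedily fills fixed-length buffers of `max_length` tokens: each sample's length is taken as its first EOS position plus one, and when the next sample would overflow, the current buffer is flushed with `max_length` appended to its boundary list and a fresh buffer is started. The loop body is cut off in this excerpt, so this sketch assumes that a packed sample appends its end offset to the boundary list and that a sample which does not fit is retried in the new buffer; `greedy_pack` and `_close` are hypothetical names, and the sketch works on lengths only.

```python
def _close(boundaries, max_length):
    # Treat any trailing padding as a final segment that ends at max_length.
    return boundaries if boundaries[-1] == max_length else boundaries + [max_length]


def greedy_pack(lengths, max_length):
    """Greedily pack sequences, in order, into buffers of `max_length` tokens.

    Returns one boundary list per buffer: offsets starting at 0, one per packed
    sequence end, and closing at max_length.
    """
    packs = []
    boundaries, curr, idx = [0], 0, 0
    while idx < len(lengths):
        n = lengths[idx]
        if n > max_length:
            raise ValueError(f"sequence {idx} has length {n} > max_length={max_length}")
        if curr + n > max_length:
            # Buffer full: flush it and retry the same sequence in a fresh buffer.
            packs.append(_close(boundaries, max_length))
            boundaries, curr = [0], 0
            continue
        curr += n
        boundaries.append(curr)
        idx += 1
    if curr > 0:
        packs.append(_close(boundaries, max_length))
    return packs


packs = greedy_pack([5, 4, 3, 6], max_length=10)
```

Filling strictly in order can leave a buffer partly empty when a long sample arrives; in the script, that unused tail stays as padding (`PAD_ID` inputs, -100 labels, weight 0).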

THUDM/LongWriter/blob/main/train/sort_and_group.py:

from transformers import AutoTokenizer, AutoModel
import copy
import torch
import json, os
import multiprocessing
from tqdm import tqdm
import numpy as np
import random
import argparse

max_length = 32768

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--group_size', default=8, type=int)
    parser.add_argument('--train_file', type=str)
    return parser.parse_args(args)

def main(args):
    filepath = args.train_file
    group_size = args.group_size
    if "llama" in filepath.lower():  # Llama-3 family tokenizer ids
        PAD_ID = 128004
        EOS_ID = 128001
    else:  # GLM-4 tokenizer ids
        PAD_ID = 151330
        EOS_ID = 151329

    # pack
    print("load pack")
    input_ids = torch.from_numpy(np.load(os.path.join(filepath, 'inputs.npy')))
    labels = torch.from_numpy(np.load(os.path.join(filepath, 'labels.npy')))
    input_ids = input_ids[:, :max_length]
    labels = labels[:, :max_length]
    num, _ = input_ids.shape
    new_inputs = []
    new_labels = []
    new_weights = []
    attention_masks = []
    tmp_input = torch.full((max_length,), PAD_ID, dtype=torch.int64)
    tmp_label = torch.full((max_length,), -100, dtype=torch.int64)
    tmp_weight = torch.full((max_length,), 0., dtype=torch.float32)
    attention_mask = [0]
    curr_idx = 0
    idx = 0
    total_len = []
    while idx < num:
        print(idx, num)
        input_id, label = input_ids[idx], labels[idx]
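        # Position of the first EOS token; argmax returns 0 when no EOS is present,
        # in which case the sequence is treated as filling the entire row.
        eos_indice = (input_id == EOS_ID).int().argmax().item()
        eos_indice = max_length-1 if eos_indice == 0 else eos_indice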
        if curr_idx + eos_indice + 1 > max_length: # full, start new pack
            total_len.append(len(attention_mask))
            new_inputs.append(tmp_input)
            new_labels.append(tmp_label)
            new_weights.append(tmp_weight)
            attention_masks.append(attention_mask+[max_length])
            curr_idx = 0
            tmp_input = torch.full((max_length,), PAD_ID, dtype=torch.int64)
            tmp_label = torch.full((max_length,), -100, dtype=torch.int64)
            tmp_weight = torch.full((max_length,), 0., dtype=torch.float32)
            attention_mask = [0]
        else: # pack in
            tmp_input[curr_idx: curr_idx+eos_indice+1] = input_id[:eos_indice+1]
            tmp_label[curr_idx: curr_idx+eos_indice+1] = label[:eos_indice+1]
            weight = torch.where(label[:eos_indice+1] == -100, 0, 1)
            if weight.sum() > 0.5:
                weight = weight / weight.sum()
            tmp_weight[curr_idx: curr_idx+eos_indice+1] = weight
            curr_idx += (eos_indice+1)
            attention_mask.append(curr_idx)
            idx += 1
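    if curr_idx > 0:  # keep the final, partially filled pack
        total_len.append(len(attention_mask))
        new_inputs.append(tmp_input)
        new_labels.append(tmp_label)
        new_weights.append(tmp_weight)
        attention_masks.append(attention_mask+[max_length])
    input_ids = torch.stack(new_inputs, dim=0)
    labels = torch.stack(new_labels, dim=0)
    weights = torch.stack(new_weights, dim=0)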

    np.save(os.path.join(filepath, 'inputs_pack.npy'), input_ids.numpy().astype(np.int64))
    np.save(os.path.join(filepath, 'labels_pack.npy'), labels.numpy().astype(np.int64))
    np.save(os.path.join(filepath, 'weights_pack.npy'), weights.numpy())
    with open(os.path.join(filepath, 'attention_masks_pack.json'), 'w') as f:
        json.dump(attention_masks, f)  # per-pack cumulative sequence boundaries
    print(np.mean(total_len))

if __name__ == '__main__':
    main(parse_args())
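`sort_and_group.py` packs several tokenized examples into each `max_length` row, records where each example ends, and gives every example's supervised tokens weights that sum to 1. The sketch below reproduces that greedy packing on plain Python lists under simplifying assumptions: each input is an already trimmed `(input_ids, labels)` pair (the script instead cuts each row at its first `EOS_ID`), and the final partially filled row is flushed after the loop. `pack_sequences` and `_finish_row` are hypothetical helpers, not part of the repository.

```python
import numpy as np

IGNORE_INDEX = -100  # label value excluded from the loss


def _finish_row(ids, labels, weights, bounds, max_length, pad_id):
    """Right-pad one packed row to max_length and close its boundary list."""
    pad = max_length - len(ids)
    return (ids + [pad_id] * pad,
            labels + [IGNORE_INDEX] * pad,
            weights + [0.0] * pad,
            bounds + [max_length])


def pack_sequences(sequences, max_length, pad_id):
    """Greedily pack (input_ids, labels) pairs into rows of exactly max_length tokens.

    Returns packed inputs, labels, per-token loss weights, and, for each row, the
    cumulative boundary offsets [0, end_1, ..., end_k, max_length].
    """
    rows = []
    ids_row, labels_row, weights_row, bounds = [], [], [], [0]
    for ids, labels in sequences:
        ids, labels = list(ids[:max_length]), list(labels[:max_length])  # truncate like the script
        if len(ids_row) + len(ids) > max_length:  # row is full: close it, start a new one
            rows.append(_finish_row(ids_row, labels_row, weights_row, bounds, max_length, pad_id))
            ids_row, labels_row, weights_row, bounds = [], [], [], [0]
        mask = [0.0 if lab == IGNORE_INDEX else 1.0 for lab in labels]
        total = sum(mask)
        weights = [m / total for m in mask] if total > 0 else mask  # one example's weights sum to 1
        ids_row += ids
        labels_row += labels
        weights_row += weights
        bounds.append(len(ids_row))
    if ids_row:  # flush the final, partially filled row
        rows.append(_finish_row(ids_row, labels_row, weights_row, bounds, max_length, pad_id))

    inputs = np.array([r[0] for r in rows], dtype=np.int64)
    labels = np.array([r[1] for r in rows], dtype=np.int64)
    weights = np.array([r[2] for r in rows], dtype=np.float32)
    boundaries = [r[3] for r in rows]
    return inputs, labels, weights, boundaries


seqs = [([1, 2, 3], [-100, 2, 3]),
        ([4, 5], [-100, 5]),
        ([6, 7, 8, 9], [-100, -100, 8, 9])]
inputs, labels, weights, boundaries = pack_sequences(seqs, max_length=6, pad_id=0)
print(inputs.tolist())      # [[1, 2, 3, 4, 5, 0], [6, 7, 8, 9, 0, 0]]
print(boundaries)           # [[0, 3, 5, 6], [0, 4, 6]]
print(weights[0].tolist())  # [0.0, 0.5, 0.5, 0.0, 1.0, 0.0]
```

Because each example's weights sum to 1, a long example and a short one contribute equally to a weighted loss, rather than in proportion to their token counts.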

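The boundary lists saved to `attention_masks_pack.json` resemble the cumulative sequence offsets (`cu_seqlens`) taken by variable-length attention kernels, and `weights_pack.npy` presumably scales the per-token loss. How the repository's training code consumes them is not shown in this file, so the following is only a hypothetical PyTorch illustration. It assumes a dense boolean mask is acceptable and that `logits` and `labels` are already aligned position-by-position (any next-token shift applied beforehand). `block_causal_mask` and `weighted_packed_loss` are illustrative names.

```python
import math

import torch
import torch.nn.functional as F


def block_causal_mask(boundaries, max_length):
    """Boolean [max_length, max_length] mask: True where query i may attend to key j.

    Attention is causal and restricted to the segment [boundaries[k], boundaries[k+1]).
    """
    segment_id = torch.zeros(max_length, dtype=torch.long)
    for k, (start, end) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        segment_id[start:end] = k
    same_segment = segment_id[:, None] == segment_id[None, :]
    causal = torch.tril(torch.ones(max_length, max_length)).bool()
    return same_segment & causal


def weighted_packed_loss(logits, labels, weights, num_sequences):
    """Per-token cross-entropy scaled by the packed weights, averaged over sequences."""
    token_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1),
        ignore_index=-100, reduction="none",  # ignored positions contribute 0
    )
    return (token_loss * weights.reshape(-1)).sum() / num_sequences


mask = block_causal_mask([0, 3, 5, 6], max_length=6)
print(bool(mask[4, 3]), bool(mask[3, 2]))  # True False

logits = torch.zeros(1, 6, 4)  # uniform predictions over a 4-token vocabulary
labels = torch.tensor([[-100, 2, 3, -100, 1, -100]])
weights = torch.tensor([[0.0, 0.5, 0.5, 0.0, 1.0, 0.0]])
loss = weighted_packed_loss(logits, labels, weights, num_sequences=2)
print(round(loss.item(), 4), round(math.log(4), 4))  # 1.3863 1.3863
```

Dividing by the number of real (non-padding) sequences turns the weighted sum into a mean of per-sequence average losses; that normalization is an assumption of this sketch.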
THUDM/LongWriter/blob/main/train/trainer.py:

from transformers import Trainer

from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch.utils.data import Dataset, SequentialSampler
from transformers.training_args import TrainingArguments

class TrainerNoShuffle(Trainer):
    def __init__(
        self,
        model = None,
        args: TrainingArguments = None,
        data_collator: Optional["DataCollator"] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], "PreTrainedModel"] = None,
        compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
        callbacks: Optional[List["TrainerCallback"]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    ):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
        )

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: # disable shuffling
        return SequentialSampler(self.train_dataset)
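`Trainer._get_train_sampler` normally returns a random sampler, so the training rows would be visited in shuffled order. Returning `SequentialSampler` instead keeps the packs in the order they were written to disk (the raw lines were already shuffled once in the pre-tokenization script). A small PyTorch sketch of the difference, with a list of integers standing in for packed rows:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

packs = list(range(8))  # stand-in for 8 pre-packed training rows

# SequentialSampler yields indices 0..len-1 in order, so rows keep their on-disk order.
sequential = list(SequentialSampler(packs))
print(sequential)  # [0, 1, 2, 3, 4, 5, 6, 7]

# RandomSampler, the usual default, visits the same indices in a permuted order.
shuffled = list(RandomSampler(packs, generator=torch.Generator().manual_seed(0)))
print(sorted(shuffled) == sequential)  # True

loader = DataLoader(packs, batch_size=4, sampler=SequentialSampler(packs))
print([batch.tolist() for batch in loader])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```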

THUDM/LongWriter/blob/main/trans_web_demo.py:

"""
This script creates an interactive web demo for a LongWriter model (by default
THUDM/LongWriter-glm4-9b, based on GLM-4-9B; override with MODEL_PATH) using Gradio,
a Python library for quickly building UIs for machine learning models.
Users interact with the model through a chat-like interface.
"""

import os
from pathlib import Path
from threading import Thread
from typing import Union

import gradio as gr
import torch
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/LongWriter-glm4-9b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    model_dir = _resolve_path(model_dir)
    if (model_dir / 'adapter_config.json').exists():
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
    )
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # stop_ids = model.config.eos_token_id
        stop_ids = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


def parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\\`")  # escaped backslash avoids an invalid escape sequence
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def predict(history, prompt, max_length, top_p, temperature):
    stop = StopOnTokens()
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})
    for idx, (user_msg, model_msg) in enumerate(history):
        if prompt and idx == 0:
            continue
        if idx == len(history) - 1 and not model_msg:
            # messages.append({"role": "user", "content": user_msg})
            query = user_msg
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True, skip_special_tokens=True)
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token == '<|user|>':
            continue
        elif new_token:
            history[-1][1] += new_token
        yield history


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">LongWriter Chat Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=5, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 32768, value=32768, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
        return "", history + [[parse_text(query), ""]]


    def set_prompt(prompt_text):
        return [[parse_text(prompt_text), "Prompt set successfully"]]


    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)

    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

demo.queue()
demo.launch(server_name="127.0.0.1", server_port=8008, inbrowser=True, share=True)
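`predict()` rebuilds the chat from the Gradio history before each generation: the system prompt (if any) goes first, row 0 is skipped when a prompt is set because `set_prompt()` writes the prompt into that row, and the last row, whose reply is still empty, becomes the `query` passed to `tokenizer.build_chat_input`. The pure-Python sketch below isolates that conversion so it can be checked without loading a model. `history_to_messages` is a hypothetical helper, and raising `ValueError` when there is no pending turn is an addition (the original would fail with an unbound `query`).

```python
def history_to_messages(history, prompt):
    """Mirror predict(): turn Gradio [[user, assistant], ...] rows into chat messages.

    Returns (messages, query), where query is the pending user turn with an empty reply.
    """
    messages = []
    query = None
    if prompt:
        messages.append({"role": "system", "content": prompt})
    for idx, (user_msg, model_msg) in enumerate(history):
        if prompt and idx == 0:
            continue  # assumes row 0 holds the prompt echo written by set_prompt()
        if idx == len(history) - 1 and not model_msg:
            query = user_msg  # last turn still awaiting a reply
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    if query is None:
        raise ValueError("history has no pending user turn")
    return messages, query


hist = [["Be concise", "Prompt set successfully"], ["Hi", "Hello!"], ["Write a story", ""]]
messages, query = history_to_messages(hist, "Be concise")
print(query)          # Write a story
print(len(messages))  # 3
```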