ChatGLM Batch API CLI Example
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Python Version: python 3.11.6
# @Time : 2024/09/11 23:40
# @Author : Jin
"""Overview:
This script is an example of using the GLM Batch API. You can use it to upload
files and download the results.
Official reference: https://bigmodel.cn/dev/howuse/batchapi
## Setup
1. First, you need a GLM API Key. Register an account on the
   [GLM website](https://bigmodel.cn) and obtain an API Key.
2. Next, set the API Key as the environment variable `GLM_API_KEY`, or enter
   it when the script prompts for it.
3. Then make the `demo_glm_batch.py` script executable:
```
chmod +x demo_glm_batch.py
```
4. Then install the dependency. The script needs the openai Python SDK:
```
pip install "openai>=1.43.1"
```
5. Finally, run the script:
First, create a sample file, e.g. `upload_example.jsonl`:
```
./demo_glm_batch.py --examples
```
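The sample file contains one request per line; a minimal sketch of the first
line (see `UPLOAD_EXAMPLES` in the code for the full payloads):
```
{"custom_id": "request-1", "method": "POST", "url": "/v4/chat/completions", "body": {"model": "glm-4", "messages": [...]}}
```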
Then, upload the file:
```
./demo_glm_batch.py --files .glm_history/upload_example.jsonl
```
You can check the console on the official site to confirm whether the upload
succeeded, or query the task status from this script instead:
https://bigmodel.cn/console/batch/task
Finally, download the results:
```
./demo_glm_batch.py --job_id <job_id>
```
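Results are saved to `.glm_history/download/batch_result_<task_id>.jsonl`. To
inspect them in a spreadsheet, the script's own converter can turn a result
file into CSV (`<task_id>` is whatever ID your job produced):
```
./demo_glm_batch.py --jsonl-to-csv download .glm_history/download/batch_result_<task_id>.jsonl
```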
""" | |
# 检查是否安装了openai的python sdk | |
# 使用built-in检查 | |
import importlib.util | |
if importlib.util.find_spec("openai") is None: | |
print( | |
"Error: Please install the openai python sdk first. You can run `pip install openai>=1.43.1` to install it." | |
"" | |
) | |
exit(1) | |
import argparse
import getpass
import json
import logging
import os
import time
import uuid

from httpx import Client, HTTPTransport, Proxy
from openai import OpenAI

# Configure the logger.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# source: https://bigmodel.cn/dev/howuse/batchapi
UPLOAD_EXAMPLES = [
    {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v4/chat/completions",
        "body": {
            "model": "glm-4",
            "messages": [
                {"role": "system", "content": "你是一个意图分类器."},
                {
                    "role": "user",
                    "content": '#任务:对以下用户评论进行情感分类和特定问题标签标注,只输出结果,# 评论:review = "订单处理速度太慢,等了很久才发货。"# 输出格式:\'\'\'{"分类标签": " ", "特定问题标注": " " } \'\'\'',
                },
            ],
        },
    },
    {
        "custom_id": "request-2",
        "method": "POST",
        "url": "/v4/chat/completions",
        "body": {
            "model": "glm-4",
            "messages": [
                {"role": "system", "content": "你是一个意图分类器."},
                {
                    "role": "user",
                    "content": '#任务:对以下用户评论进行情感分类和特定问题标签标注,只输出结果,# 评论:review = ",商品有点小瑕疵,不过客服处理得很快,总体满意。",# 输出格式:\'\'\'{",分类标签": " ", "特定问题标注": " " } \'\'\'',
                },
            ],
        },
    },
    {
        "custom_id": "request-3",
        "method": "POST",
        "url": "/v4/chat/completions",
        "body": {
            "model": "glm-4",
            "messages": [
                {"role": "system", "content": "你是一个意图分类器."},
                {
                    "role": "user",
                    "content": '#任务:对以下用户评论进行情感分类和特定问题标签标注,只输出结果,# 评论:review = "这款产品性价比很高,非常满意。"# 输出格式:\'\'\'{"分类标签": " ", "特定问题标注": " " } \'\'\'',
                },
            ],
        },
    },
]
def init_api_key():
    # Read the key from the environment; fall back to an interactive prompt.
    API_KEY = os.environ.get("GLM_API_KEY", "")
    if API_KEY == "":
        # Read the key like a password (no echo).
        API_KEY = getpass.getpass("Please input your GLM API Key: ")
    return API_KEY
def create_glm_instance():
    # Create a GLM API client.
    BASE_URL = "https://open.bigmodel.cn/api/paas/v4"
    API_KEY = init_api_key()
    # Build a custom http client (optionally routed through a proxy).
    enable_proxy = False
    _proxy = Proxy(url=os.environ.get("http_proxy")) if enable_proxy else None
    _httpx_client = Client(transport=HTTPTransport(retries=3, proxy=_proxy))
    glm = OpenAI(api_key=API_KEY, base_url=BASE_URL, http_client=_httpx_client)
    return glm
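# Minimal usage sketch (assuming GLM_API_KEY is set; the client is a plain
# openai.OpenAI instance pointed at the GLM base URL, so the standard openai
# SDK surface is used throughout this script):
#
#     glm = create_glm_instance()
#     batch = glm.batches.retrieve("<batch_id>")  # "<batch_id>" is a placeholder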
def upload_files(glm, files):
    """Uploads a list of files and returns a list of their corresponding file IDs.
    Args:
        files (list): A list of file paths to be uploaded.
    Returns:
        list: A list of file IDs corresponding to the uploaded files.
    """
    upload_file_ids = []
    for file in files:
        # Use a context manager so the file handle is closed after upload.
        with open(file, "rb") as fp:
            upload_file = glm.files.create(file=fp, purpose="batch")
        upload_file_ids.append(upload_file.id)
        logging.info(f"File {file} uploaded. File ID: {upload_file.id}")
    return upload_file_ids
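# Usage sketch (assuming the sample file created by `--examples` exists):
#
#     upload_file_ids = upload_files(glm, [".glm_history/upload_example.jsonl"])
#
# Each entry is the server-assigned ID of one uploaded file, consumed by
# submit_batch_tasks() below.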
def submit_batch_tasks(glm, upload_file_ids):
    """Submits batch tasks for sentiment classification.
    Args:
        upload_file_ids (list): A list of upload file IDs.
    Returns:
        list: A list of task IDs created for the batch tasks.
    """
    task_ids = []
    for upload_file_id in upload_file_ids:
        batch_task = glm.batches.create(
            completion_window="24h",
            input_file_id=upload_file_id,
            endpoint="/v4/chat/completions",
            metadata={"description": "Sentiment classification"},
        )
        task_ids.append(batch_task.id)
        logging.info(f"Task {batch_task.id} created. Related file ID: {upload_file_id}")
    return task_ids
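# Note: each batch targets the same endpoint as the JSONL request lines
# ("/v4/chat/completions") and is given up to 24h to complete
# (completion_window="24h"). Progress can also be watched in the web console:
# https://bigmodel.cn/console/batch/task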
def check_task_status(glm, task_ids):
    """Checks the status of the given task IDs.
    Args:
        task_ids (list): A list of task IDs.
    Returns:
        dict: A dictionary mapping task IDs to their corresponding status.
    """
    task_states = {}
    for task_id in task_ids:
        task = glm.batches.retrieve(task_id)
        task_states[task_id] = task.status
    return task_states
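# Example return value (sketch; the batch ID is the sample one shown in the
# comment inside _download_jsonl_to_csv, and status strings follow the Batch
# API lifecycle: validating, in_progress, finalizing, cancelling, completed):
#
#     {"batch_1833899973411479552": "in_progress"}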
def download_results(glm, task_ids):
    """Downloads the completed results for the given task IDs.
    Args:
        task_ids (list): A list of task IDs.
    Returns:
        None
    """
    # Make sure the local `.glm_history` folders exist.
    if not os.path.exists(".glm_history"):
        os.makedirs(".glm_history")
    if not os.path.exists(".glm_history/download"):
        os.makedirs(".glm_history/download")
    for task_id in task_ids:
        task = glm.batches.retrieve(task_id)
        if task.status.lower() == "completed":
            content = glm.files.content(task.output_file_id)
            content.write_to_file(f".glm_history/download/batch_result_{task_id}.jsonl")
            logging.info(
                (
                    f"Task {task_id} completed. Result downloaded to"
                    f" .glm_history/download/batch_result_{task_id}.jsonl"
                )
            )
        else:
            logging.error(f"Task {task_id} failed: {task.status}")
def record_job(files, upload_file_ids, task_ids, task_states, job_id=None):
    tasks = []
    # task_states is a dict keyed by task ID, so look states up by ID rather
    # than zipping over the dict (which would only yield its keys).
    for file, file_id, task_id in zip(files, upload_file_ids, task_ids):
        tasks.append(
            {
                "file": file,
                "file_id": file_id,
                "task_id": task_id,
                "task_state": task_states[task_id],
            }
        )
    # Make sure the local `.glm_history` folders exist.
    if not os.path.exists(".glm_history"):
        os.makedirs(".glm_history")
    if not os.path.exists(".glm_history/job"):
        os.makedirs(".glm_history/job")
    # Generate a random job ID if none was given.
    if job_id is None:
        job_id = str(uuid.uuid4())
    with open(f".glm_history/job/jobs_{job_id}.jsonl", "w") as f:
        for task in tasks:
            f.write(json.dumps(task) + "\n")
    logging.info(f"Jobs recorded to .glm_history/job/jobs_{job_id}.jsonl")
    return job_id
def update_job(job_id, task_states):
    # Make sure the local `.glm_history` folders exist.
    if not os.path.exists(".glm_history"):
        os.makedirs(".glm_history")
    if not os.path.exists(".glm_history/job"):
        os.makedirs(".glm_history/job")
    tasks = load_job(job_id)
    for task in tasks:
        task["task_state"] = task_states[task["task_id"]]
    with open(f".glm_history/job/jobs_{job_id}.jsonl", "w") as f:
        for task in tasks:
            f.write(json.dumps(task) + "\n")
    logging.info(f"Job updated to .glm_history/job/jobs_{job_id}.jsonl")
    return job_id
# Load the task records for a job from its history file.
def load_job(job_id):
    job_file = f".glm_history/job/jobs_{job_id}.jsonl"
    tasks = []
    with open(job_file, "r") as f:
        for line in f:
            tasks.append(json.loads(line))
    logging.info(f"Tasks loaded from {job_file}")
    return tasks
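# Each line of a job file is one task record written by record_job(); a sketch
# of one line (the IDs are placeholders):
#
#     {"file": ".glm_history/upload_example.jsonl", "file_id": "<file_id>",
#      "task_id": "<batch_id>", "task_state": "validating"}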
# List all recorded jobs.
def list_jobs():
    # Make sure the local `.glm_history` folders exist.
    if not os.path.exists(".glm_history"):
        os.makedirs(".glm_history")
    if not os.path.exists(".glm_history/job"):
        os.makedirs(".glm_history/job")
    jobs = {}
    for root, dirs, files in os.walk(".glm_history/job"):
        for file in files:
            if file.endswith(".jsonl") and file.startswith("jobs_"):
                # Extract the job ID from the file name.
                job_id = file.split("_")[1].split(".")[0]
                jobs[job_id] = load_job(job_id)
    return jobs
def main(files=None, job_id=None):
    # Create the API client.
    glm = create_glm_instance()
    if job_id is None:
        # Upload the given files to GLM.
        upload_file_ids = upload_files(glm, files)
        # Create one Batch task per uploaded file ID.
        task_ids = submit_batch_tasks(glm, upload_file_ids)
        # Take an initial snapshot of the task states.
        task_states = check_task_status(glm, task_ids)
        # Record the job locally.
        job_id = record_job(files, upload_file_ids, task_ids, task_states)
    job = load_job(job_id)
    task_ids = [task["task_id"] for task in job]
    # Wait for the tasks to finish.
    while True:
        time.sleep(5)
        # Refresh the task states.
        task_states = check_task_status(glm, task_ids)
        update_job(job_id, task_states)
        should_break = True
        # Keep waiting while any task is still in flight.
        for task_id in task_ids:
            if task_states[task_id].lower() in (
                "validating",
                "in_progress",
                "finalizing",
                "cancelling",
            ):
                should_break = False
                break
        if should_break:
            break
    for task_id, status in task_states.items():
        logging.info(f"Task {task_id} status: {status}")
    # Download the task results.
    download_results(glm, task_ids)
def _upload_jsonl_to_csv(jsonl_file, csv_file):
    # Validate file extensions.
    if not jsonl_file.endswith(".jsonl"):
        print(f"Error: The input file must be a JSONL file: {jsonl_file}")
        return
    if not csv_file.endswith(".csv"):
        print(f"Error: The output file must be a CSV file: {csv_file}")
        return
    import csv

    with open(jsonl_file, "r") as f:
        data = f.readlines()
    fieldnames = [
        "custom_id",
        "method",
        "url",
        "body.model",
    ]
    # Size the header for the longest message list in the file.
    max_messages_len = max(map(lambda x: len(json.loads(x)["body"]["messages"]), data))
    for i in range(1, max_messages_len + 1):
        fieldnames.extend([f"body.messages.{i}.role", f"body.messages.{i}.content"])
    with open(csv_file, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for line in data:
            # Build a fresh row per line so rows with shorter message lists
            # do not inherit values from previous rows.
            row = {}
            raw_json = json.loads(line)
            row["custom_id"] = raw_json["custom_id"]
            row["method"] = raw_json["method"]
            row["url"] = raw_json["url"]
            row["body.model"] = raw_json["body"]["model"]
            for i, message in enumerate(raw_json["body"]["messages"], 1):
                row[f"body.messages.{i}.role"] = message["role"]
                row[f"body.messages.{i}.content"] = message["content"]
            writer.writerow(row)
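# For the three UPLOAD_EXAMPLES requests above (one system + one user message
# each), the generated CSV header would be:
#
#     custom_id,method,url,body.model,
#     body.messages.1.role,body.messages.1.content,
#     body.messages.2.role,body.messages.2.content
#
# (split across lines here for readability; the actual header is one line).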
def _download_jsonl_to_csv(jsonl_file, csv_file):
    # Validate file extensions.
    if not jsonl_file.endswith(".jsonl"):
        print(f"Error: The input file must be a JSONL file: {jsonl_file}")
        return
    if not csv_file.endswith(".csv"):
        print(f"Error: The output file must be a CSV file: {csv_file}")
        return
    import csv

    with open(jsonl_file, "r") as f:
        data = f.readlines()
    # A result line looks like:
    # {"response":{"status_code":200,"body":{"created":1726072098,"usage":{"completion_tokens":24,"prompt_tokens":72,"total_tokens":96},"model":"glm-4","id":"9012832980776182113","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"```json\n{\"分类标签\": \"负面\", \"特定问题标注\": \"物流配送问题\" }\n```"}}],"request_id":"request-1"}},"custom_id":"request-1","id":"batch_1833899973411479552"}
    fieldnames = [
        "response.status_code",
        "response.body.created",
        "response.body.usage.completion_tokens",
        "response.body.usage.prompt_tokens",
        "response.body.usage.total_tokens",
        "response.body.model",
        "response.body.id",
        "response.body.request_id",
        "custom_id",
        "id",
    ]
    # Size the header for the longest choices list in the file.
    max_choices_len = max(
        map(lambda x: len(json.loads(x)["response"]["body"]["choices"]), data)
    )
    for i in range(1, max_choices_len + 1):
        fieldnames.extend(
            [
                f"body.choices.{i}.finish_reason",
                f"body.choices.{i}.index",
                f"body.choices.{i}.message.role",
                f"body.choices.{i}.message.content",
            ]
        )
    with open(csv_file, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for line in data:
            # Build a fresh row per line so rows with fewer choices do not
            # inherit values from previous rows.
            row = {}
            raw_json = json.loads(line)
            row["response.status_code"] = raw_json["response"]["status_code"]
            row["response.body.created"] = raw_json["response"]["body"]["created"]
            usage = raw_json["response"]["body"]["usage"]
            row["response.body.usage.completion_tokens"] = usage["completion_tokens"]
            row["response.body.usage.prompt_tokens"] = usage["prompt_tokens"]
            row["response.body.usage.total_tokens"] = usage["total_tokens"]
            row["response.body.model"] = raw_json["response"]["body"]["model"]
            row["response.body.id"] = raw_json["response"]["body"]["id"]
            row["response.body.request_id"] = raw_json["response"]["body"]["request_id"]
            row["custom_id"] = raw_json["custom_id"]
            row["id"] = raw_json["id"]
            for i, choice in enumerate(raw_json["response"]["body"]["choices"], 1):
                row[f"body.choices.{i}.finish_reason"] = choice["finish_reason"]
                row[f"body.choices.{i}.index"] = choice["index"]
                row[f"body.choices.{i}.message.role"] = choice["message"]["role"]
                row[f"body.choices.{i}.message.content"] = choice["message"]["content"]
            writer.writerow(row)
def _upload_csv_to_jsonl(csv_file, jsonl_file):
    # Validate file extensions.
    if not csv_file.endswith(".csv"):
        print(f"Error: The input file must be a CSV file: {csv_file}")
        return
    if not jsonl_file.endswith(".jsonl"):
        print(f"Error: The output file must be a JSONL file: {jsonl_file}")
        return
    import csv

    with open(csv_file, "r") as f:
        reader = csv.DictReader(f)
        data = list(reader)
    with open(jsonl_file, "w") as f:
        for row in data:
            payload = {}
            payload["custom_id"] = row["custom_id"]
            payload["method"] = row["method"]
            payload["url"] = row["url"]
            payload["body"] = {}
            body = payload["body"]
            body["model"] = row["body.model"]
            body["messages"] = []
            # Rebuild the messages list from the flattened column names
            # ("body.messages.<i>.role" / "body.messages.<i>.content").
            for key in row:
                if key.startswith("body.messages."):
                    _, _, i, k = key.split(".")
                    i = int(i)
                    if len(body["messages"]) < i:
                        body["messages"].append({})
                    body["messages"][i - 1][k] = row[key]
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")
def _download_csv_to_jsonl(csv_file, jsonl_file):
    # Validate file extensions.
    if not csv_file.endswith(".csv"):
        print(f"Error: The input file must be a CSV file: {csv_file}")
        return
    if not jsonl_file.endswith(".jsonl"):
        print(f"Error: The output file must be a JSONL file: {jsonl_file}")
        return
    import csv

    with open(csv_file, "r") as f:
        reader = csv.DictReader(f)
        data = list(reader)
    with open(jsonl_file, "w") as f:
        for row in data:
            payload = {}
            payload["response"] = {}
            payload["custom_id"] = row["custom_id"]
            payload["id"] = row["id"]
            response = payload["response"]
            # Restore numeric fields as integers so the roundtrip keeps the
            # original JSONL types.
            response["status_code"] = int(row["response.status_code"])
            response["body"] = {}
            body = response["body"]
            body["created"] = int(row["response.body.created"])
            body["usage"] = {}
            usage = body["usage"]
            usage["completion_tokens"] = int(row["response.body.usage.completion_tokens"])
            usage["prompt_tokens"] = int(row["response.body.usage.prompt_tokens"])
            usage["total_tokens"] = int(row["response.body.usage.total_tokens"])
            body["model"] = row["response.body.model"]
            body["id"] = row["response.body.id"]
            body["request_id"] = row["response.body.request_id"]
            body["choices"] = []
            # Rebuild the choices list from the flattened column names
            # ("body.choices.<i>.finish_reason", "body.choices.<i>.message.role", ...).
            for key in row:
                if key.startswith("body.choices."):
                    _, _, *k = key.split(".")
                    if len(k) == 2:
                        i, k = k
                    else:
                        i, *k = k
                    i = int(i)
                    if len(body["choices"]) < i:
                        body["choices"].append({})
                    if k == "index":
                        body["choices"][i - 1]["index"] = int(row[key])
                    elif k == "finish_reason":
                        body["choices"][i - 1]["finish_reason"] = row[key]
                    elif k[0] == "message":
                        if body["choices"][i - 1].get("message") is None:
                            body["choices"][i - 1]["message"] = {}
                        if ".".join(k) == "message.content":
                            body["choices"][i - 1]["message"]["content"] = row[key]
                        elif ".".join(k) == "message.role":
                            body["choices"][i - 1]["message"]["role"] = row[key]
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")
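# Roundtrip sketch (file names are placeholders; beyond the extension checks,
# the converters only look at file contents, not names):
#
#     _download_jsonl_to_csv(".glm_history/download/batch_result_x.jsonl", "result.csv")
#     _download_csv_to_jsonl("result.csv", "roundtrip.jsonl")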
def jsonl_to_csv(file_type, jsonl_file, csv_file):
    if file_type == "upload":
        _upload_jsonl_to_csv(jsonl_file, csv_file)
    elif file_type == "download":
        _download_jsonl_to_csv(jsonl_file, csv_file)


def csv_to_jsonl(file_type, csv_file, jsonl_file):
    if file_type == "upload":
        _upload_csv_to_jsonl(csv_file, jsonl_file)
    elif file_type == "download":
        _download_csv_to_jsonl(csv_file, jsonl_file)
if __name__ == "__main__":
    # Temporarily silence logging while handling CLI-only paths.
    logging.disable(logging.CRITICAL)
    parser = argparse.ArgumentParser()
    parser.add_argument("--job_id", help="The job ID to be resumed.")
    parser.add_argument(
        "--files",
        nargs="+",
        help="The list of file paths to be processed in the batch task.",
    )
    parser.add_argument(
        "--list-jobs",
        action="store_true",
        help="List all the jobs recorded in the history.",
    )
    parser.add_argument(
        "--examples",
        action="store_true",
        help="Show examples of file data about `GLM Batch API`, upload and download.",
    )
    # --jsonl-to-csv takes 2 or 3 arguments: FILE_TYPE, JSONL_FILE, [CSV_FILE];
    # FILE_TYPE must be either "upload" or "download".
    parser.add_argument(
        "--jsonl-to-csv",
        nargs="+",
        help='Convert JSONL to CSV. FILE_TYPE must be "upload" or "download". '
        "JSONL_FILE is required. CSV_FILE is optional.",
    )
    # --csv-to-jsonl takes 2 or 3 arguments: FILE_TYPE, CSV_FILE, [JSONL_FILE];
    # FILE_TYPE must be either "upload" or "download".
    parser.add_argument(
        "--csv-to-jsonl",
        nargs="+",
        help='Convert CSV to JSONL. FILE_TYPE must be "upload" or "download". '
        "CSV_FILE is required. JSONL_FILE is optional.",
    )
    # Note: --ask is accepted but not handled anywhere below.
    parser.add_argument(
        "--ask",
        action="store_true",
        help="Ask a question using the CLI.",
    )
    args = parser.parse_args()
    if args.list_jobs:
        # List all recorded jobs.
        jobs = list_jobs()
        for job_id, tasks in jobs.items():
            print(f"Job {job_id}:")
            for task in tasks:
                print(
                    (
                        f" |--Task {task['task_id']} - {task['task_state']}:"
                        f" {task['file']}"
                    )
                )
        exit(0)
    if args.examples:
        print("Uploads must use the .jsonl format, one JSON object per line.\n")
        print("Example upload JSON:")
        # Show one example request.
        example = UPLOAD_EXAMPLES[0]
        print(json.dumps(example, ensure_ascii=False))
        # Optionally save a sample file to .glm_history/upload_example.jsonl.
        print("\nAn example file can be saved to: .glm_history/upload_example.jsonl\n")
        yes_or_no = input("Create the example file? (y/n): ")
        if yes_or_no.lower() in ("y", "yes"):
            if not os.path.exists(".glm_history"):
                os.makedirs(".glm_history")
            with open(".glm_history/upload_example.jsonl", "w") as f:
                for example in UPLOAD_EXAMPLES:
                    f.write(json.dumps(example, ensure_ascii=False) + "\n")
        exit(0)
    if args.jsonl_to_csv:
        if len(args.jsonl_to_csv) < 2 or len(args.jsonl_to_csv) > 3:
            parser.error("--jsonl-to-csv requires 2 or 3 arguments")
        file_type = args.jsonl_to_csv[0]
        jsonl_file = args.jsonl_to_csv[1]
        csv_file = args.jsonl_to_csv[2] if len(args.jsonl_to_csv) == 3 else None
        if file_type not in ["upload", "download"]:
            parser.error('FILE_TYPE must be either "upload" or "download"')
        print(f"File type: {file_type}")
        print(f"JSONL file: {jsonl_file}")
        if csv_file:
            print(f"CSV file: {csv_file}")
        else:
            csv_file = jsonl_file.replace(".jsonl", ".csv")
            print(
                f"Warning: No CSV file specified. Using same name as JSONL file: {csv_file}"
            )
        # Run the conversion.
        jsonl_to_csv(file_type, jsonl_file, csv_file)
        exit(0)
    if args.csv_to_jsonl:
        if len(args.csv_to_jsonl) < 2 or len(args.csv_to_jsonl) > 3:
            parser.error("--csv-to-jsonl requires 2 or 3 arguments")
        file_type = args.csv_to_jsonl[0]
        csv_file = args.csv_to_jsonl[1]
        jsonl_file = args.csv_to_jsonl[2] if len(args.csv_to_jsonl) == 3 else None
        if file_type not in ["upload", "download"]:
            parser.error('FILE_TYPE must be either "upload" or "download"')
        print(f"File type: {file_type}")
        print(f"CSV file: {csv_file}")
        if jsonl_file:
            print(f"JSONL file: {jsonl_file}")
        else:
            jsonl_file = csv_file.replace(".csv", ".jsonl")
            print(
                f"Warning: No JSONL file specified. Using same name as CSV file: {jsonl_file}"
            )
        # Run the conversion.
        csv_to_jsonl(file_type, csv_file, jsonl_file)
        exit(0)
    # Re-enable logging for the batch workflow.
    logging.disable(logging.NOTSET)
    if args.job_id is not None:
        # A given job ID always wins; warn if files were also passed.
        if args.files is not None:
            logging.warning("Job ID and files cannot be specified at the same time.")
        main(job_id=args.job_id)
    elif args.files is not None:
        main(files=args.files)
    else:
        logging.error("No files or job ID specified.")
        parser.print_help()