Python 的 argparse 模块的作用,以及分享一个通用代码模板

Python
48
0
0
2024-11-14
标签   Python库

argparse 是 Python 内置的一个用于命令项选项与参数解析的模块。它的作用是帮助我们处理命令行输入,轻松编写用户友好的命令行接口。

在这里插入图片描述

命令行接口的需求

  • 假设您编写了一个 Python 脚本,您希望用户能够在运行脚本时提供一些选项或参数。例如,您的脚本可能需要从命令行获取文件路径、模型参数、指定输出目录等。
  • 使用 argparse 模块,可以轻松编写用户友好的命令行界面。程序定义了它需要的参数,argparse 就会找出如何从 sys.argv 中解析出这些参数。argparse 模块还会自动生成帮助和使用信息。如果用户传入无效的参数,argparse 会显示错误消息,帮助用户正确使用程序。

我们根据一个好的代码模块来学习 argparse 模块的使用:

这份代码源自:https://github.com/XinyuanWangCS/PromptAgent/blob/main/src/main.py

import argparse


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def config():
    parser = argparse.ArgumentParser(description='Process prompt search agent arguments')
    parser.add_argument('--task_name', type=str, default='bigbench',  help='This is consistent to the task file names in tasks folder. The default is bigbench task.')  
    parser.add_argument('--search_algo', type=str, default='mcts', choices=['mcts', 'beam'], help='Prompt search algorithm. Choose from \'mcts\' and \'beam\'.')    
    parser.add_argument('--batch_size', type=int, default=5, help='Batch size depending on the memory and model size')
    parser.add_argument('--depth_limit', type=int, default=5, help="The max depth of a single searching path.")
    parser.add_argument('--train_size', type=int, default=None, help="The dataset that sample batches from.")
    parser.add_argument('--eval_size', type=int, default=50, help="Calculate reward on this set.")
    parser.add_argument('--test_size', type=int, default=0, help="Test set size.")
    parser.add_argument('--seed', type=int, default=42, help="The seed to shuffle the dataset.")
    parser.add_argument('--train_shuffle', type=str2bool, default=True, help='Shuffle training set')

    # Search
    parser.add_argument('--init_prompt', type=str, default="Let's solve the problem.", help='Initial prompt written by human.')
    parser.add_argument('--iteration_num', type=int, default=12, help='MCTS iteration number.')
    parser.add_argument('--expand_width', type=int, default=3, help="The number of batches sampled in each expansion.")
    parser.add_argument('--num_new_prompts', type=int, default=1, help="The number of new prompts sampled in each batch.")
    parser.add_argument('--post_instruction', type=str2bool, default=False, help="True: the position of instruction is question+instruction; \nFalse: the position of instruction is instruction+question")

    # MCTS
    parser.add_argument('--min_depth', type=int, default=2, help="Early stop depth: early stop is only applied when depth is larger than min_depth.")
    parser.add_argument('--w_exp', type=float, default=2.5, help="Weight of MCTS.")

    # World Model
    parser.add_argument('--pred_model', type=str, default='gpt-3.5-turbo', help='The base model that makes predictions.')
    parser.add_argument('--optim_model', type=str, default='gpt-4', help='Prompt optimizer.') 
    parser.add_argument('--pred_temperature', type=float, default=0.0)
    parser.add_argument('--optim_temperature', type=float, default=1.0)
    
    # Others
    parser.add_argument('--log_dir', type=str, default='../logs/', help='Log directory.')
    parser.add_argument('--data_dir', type=str, default=None, help='Path to the data file (if needed)')
    parser.add_argument('--api_key', type=str, default=None, help='OpenAI api key or PaLM 2 api key')

    # BeamSearch
    parser.add_argument('--beam_width', type=int, default=3)

    return parser.parse_args()


if __name__ == '__main__':
    args = config()
    print(args, type(args), "\n")
    args = vars(args)
    print(args, type(args), "\n")
    print(f"pred_model: {args['pred_model']}")
    print(f"log_dir: {args['log_dir']}")

实现了一个名为 str2bool 的辅助函数。它接收字符串参数 v 并将其转换为布尔值。如果 v 已经是布尔值,则按原样返回。否则,它会检查 v 的小写版本是否与函数中定义的 true 或 false 表示相匹配。如果匹配,则返回相应的布尔值。如果不符合任何条件,将引发 argparse.ArgumentTypeError 并给出错误信息。

实现一个名为 config() 的函数,用于设置来自 argparse.ArgumentParser 的参数解析器对象。它负责定义和处理命令行参数。argparse 模块对命令行接口的支持是围绕 argparse.ArgumentParser 实例构建的。它是参数规范的容器,具有适用于整个解析器的选项。

位置参数:这些参数是在命令行中按照顺序传递的,不带前缀。例如,parser.add_argument("filename") 表示一个位置参数,用户需要提供一个文件名。

选项参数:这些参数通常以 - 或 -- 开头,可以接受各种值。以下是一些常见的选项参数类型:

  • -c 或 --count:带值的选项,用户可以提供一个计数值。
  • -v 或 --verbose:开关标志,表示是否启用详细输出。
  • -h 或 --help:获取帮助信息。

参数值的类型:

  • int:将参数值自动转换为整数。
  • float:将参数值自动转换为浮点数。
  • str:默认类型,接受字符串值。
  • bool:布尔类型,通常用于开关标志。

ArgumentParser.add_argument() 方法将单独的参数规范附加到解析器上。它支持选项参数、接受值的选项和开/关标志。

  • type:命令行参数应该被转换成的数据类型。例如,int、float、str 等。如果不指定 type,默认是字符串类型。
  • help:参数的帮助信息。当用户请求帮助时,这个描述会显示在命令行用法字符串和各种参数的帮助消息之间。编写清晰、简洁的帮助信息对用户非常重要。
  • default:如果用户未提供某个参数,将使用默认值。

parser.add_argument 依次为解析器添加了各个参数。例如,第一个参数名称为 “–task_name”,这意味着从命令行运行脚本时,可将其作为 --task_name 传递。它需要一个字符串值(type=str),默认值为 “bigbench”。帮助参数提供了参数的说明。类似地,剩下的代码为解析器添加了更多参数,每个参数都有各自的名称、数据类型、默认值和帮助说明。

添加所有参数后,会调用 parser.parse_args() 来解析运行脚本时提供的命令行参数,并将提取的数据放入 argparse.Namespace 对象中。这里注意一下:parser.parse_args() 方法会检查通过命令行传入的参数,并将它们转换为一个命名空间(argparse.Namespace)。如果传入的参数不符合预定义的规则(例如,缺少必需的参数或者参数格式不正确),它会自动显示错误信息并退出程序。

使用了内置函数 vars() 来处理前一步得到的 argparse.Namespace 对象 args。vars() 函数返回对象的 __dict__ 属性,这是一个包含了对象所有属性及其值的字典。因此,通过调用 vars(args),将命名空间对象转换成了一个字典。这样做有几个好处:

  • 可读性:使用字典可使得后续代码更加易读,因为可以直接通过键来访问参数值,而不是通过属性。
  • 灵活性:字典提供了更多操作和遍历元素的方法,使得处理复杂情况时更加灵活。

总的来说,这段代码使用 argparse 设置了一个参数解析器,定义了多个命令行参数及其类型、默认值和帮助信息,解析了所提供的参数,最终以字典形式返回。

代码的运行结果如下:

Namespace(api_key=None, batch_size=5, beam_width=3, data_dir=None, depth_limit=5, eval_size=50, expand_width=3, init_prompt="Let's solve the problem.", iteration_num=12, log_dir='../logs/', min_depth=2, num_new_prompts=1, optim_model='gpt-4', optim_temperature=1.0, post_instruction=False, pred_model='gpt-3.5-turbo', pred_temperature=0.0, search_algo='mcts', seed=42, task_name='bigbench', test_size=0, train_shuffle=True, train_size=None, w_exp=2.5) <class 'argparse.Namespace'> 

{'task_name': 'bigbench', 'search_algo': 'mcts', 'batch_size': 5, 'depth_limit': 5, 'train_size': None, 'eval_size': 50, 'test_size': 0, 'seed': 42, 'train_shuffle': True, 'init_prompt': "Let's solve the problem.", 'iteration_num': 12, 'expand_width': 3, 'num_new_prompts': 1, 'post_instruction': False, 'min_depth': 2, 'w_exp': 2.5, 'pred_model': 'gpt-3.5-turbo', 'optim_model': 'gpt-4', 'pred_temperature': 0.0, 'optim_temperature': 1.0, 'log_dir': '../logs/', 'data_dir': None, 'api_key': None, 'beam_width': 3} <class 'dict'> 

pred_model: gpt-3.5-turbo
log_dir: ../logs/

在这里插入图片描述

对于 argparse 模块,总结一下:

  • 易于使用:argparse 是 Python 中一个非常实用的模块,用于解析命令行参数。argparse 提供了一种简单的方式来定义和解析命令行参数,使得我们的 Python 脚本能够更好地与命令行接口集成。您可以创建用户友好的命令行接口,使我们的 Python 脚本更易于使用和管理。
  • 灵活性:通过使用 argparse,我们可以轻松地解析命令行参数。它允许您定义各种选项、参数和子命令,从而使您的程序更具灵活性。您可以根据需要添加或修改选项,而无需更改源代码。
  • 自动生成帮助和用法信息:argparse 能够自动生成帮助和用法消息文本。当用户运行您的程序时,只需使用 --help 或 -h 选项,就能获得详细的帮助信息,包括可用选项、参数和用法示例。
  • 错误处理:argparse 在用户向程序传入无效参数时会发出错误消息。这有助于防止用户输入错误的选项或参数,从而提高程序的健壮性。
  • 位置参数和可选参数的灵活组合:argparse 允许您定义位置参数和可选参数,以及它们的组合。位置参数是根据其在命令行中出现的位置来处理的,而可选参数则可以根据用户的选择进行设置。

📚️ 参考链接: