发布于 

「技术」Pytorch 并行训练

介绍

常用术语

  • worker:常指 CPU
  • node:节点,通常对应一台完整的机器
  • nnodes:节点数量
  • node_rank:节点序号
  • nproc_per_node:节点上的进程数量,通常一个进程对应一个 GPU,故也表示 GPU 的数量
  • master_addr:master 的 IP 地址,也就是 rank=0 的 IP 地址
  • RANK:进程的序号,通常一个进程对应一个 GPU,全局,范围为 [0, sum(gpu)-1]
  • LOCAL_RANK:进程在节点上的序号,本地,范围为 [0, (local gpu)-1]
  • WORLD_SIZE:所有进程的和,一般等于 nnodes * nproc_per_node

数据并行

每个 worker 上复制一份模型,每个 batch 在多个 worker 之间分割,并定期汇总他们的梯度,从而保证权重版本一致

模型并行

把模型放到多个设备之上,分为流水线并行和张量并行

  • 流水线并行
    • 将模型的不同层放到不同设备之上
    • 层间并行
  • 张量并行
    • 将一层的运算分解为多个互不影响的子运算,并分配到不同设备上
    • 层内并行

两种并行方式正交互补

通信

通信原语有以下几种:

  • send
  • recv
  • broadcast
    • 将自身数据发送到集群中的其他节点
  • reduce
    • 精简操作,算符有 SUM/MIN/MAX/...
    • 每个节点获取一个输入元素数组并应用算符
  • all_reduce
    • 在所有节点上都应用相同的 reduce 操作
    • 等价于 reduce + boardcast
    • 疑问:所有节点数据相同吗?
      • 相同
  • gather
    • 将其它节点的数据收集到目标节点,返回一个列表
  • all_gather
    • 在所有节点上都应用相同的 gather 操作
    • 等价于 gather + boardcast
    • 疑问:同 reduce
  • scatter
    • 将数据的不同部分,按需发送给所有的节点
  • reduce_scatter
    • 将各节点的输入先进行求和,然后在第 0 维按卡数切分
  • all_to_all
    • 节点两两之间均发送消息,但发送缓冲区种不同的目标节点有不同的数据
  • barrier

张量并行

  • all_gather
  • 需要汇总所有数据,通信量大,常常在同节点内切分,从而使用较快的高带宽节点内通信

流水线并行

  • send / recv
  • 点对点通信,通信量小
  • 开始和结束时的 bubble
  • 优化器的跨设备同步(刷新)
  • 异步策略不需要刷新,但放松了权重更新语义

参考资料