[源碼解析] PyTorch 分布式(16) — 使用異步執行實現批處理 RPC

[源碼解析] pytorch 分布式(16) — 使用異步執行實現批處理 rpc

目錄

[源碼解析] PyTorch 分布式(16) — 使用異步執行實現批處理 RPC0x00 摘要0x01 前言1.1 先決條件1.2 基礎知識1.3 代碼0x02 啟動2.1 總體啟動2.2 啟動參數服務器0x03 參數服務器0x04 Trainer0x05 對比0xFF 參考0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分布式的基本模塊,接下來我們通過幾篇文章來看看如何把這些模塊應用到實踐之中,順便把PyTorch分布式邏輯整體梳理一下。本文介紹如何使用異步執行操作來實現批處理 RPC,大家可以學習到PyTorch對參數服務器一個新的實現方式。

本文以IMPLEMENTING batch RPC PROCESSING using ASYNCHRONOUS EXECUTIONS的翻譯為基礎,加入了自己的理解。

0x01 前言1.1 先決條件

本文的先決條件如下:

PyTorch 分布式概述分布式 RPC 框架入門使用分布式 RPC 框架實現參數服務器RPC 異步執行裝飾器

本教程演示了如何使用@rpc.functions.async_execution 裝飾器構建批處理 RPC 應用程序,這有助于通過減少被阻塞的 RPC 線程的數量,并且在被調用方整合 CUDA 操作來加快訓練速度。這與使用 TorchServer 進行批量推理的想法相同。Batch RPC 有助于將動作整合到較少的 CUDA 操作中,從而攤銷開銷。

注意:本教程需要 PyTorch v1.6.0 或更高版本。

1.2 基礎知識

之前的教程已經展示了使用torch.distributed.rpc構建分布式訓練應用程序的步驟,但他們沒有詳細說明在處理 RPC 請求時被調用方會發生什么。從 PyTorch v1.5 開始,針對每個 RPC 請求,被調用者都會啟動一個線程來執行該請求中的函數,該線程會阻塞直到該函數返回。這適用于許多用例,但有一個問題:如果用戶函數在 IO 上阻塞,例如使用嵌套的 RPC 調用或信號(例如等待不同的 RPC 請求來解除阻塞),則被調用者上的 RPC 線程將不得不空閑等待,直到 IO 完成或信號(signal事件發生。因此,RPC 被調用者使用的線程可能會使用比實際需要更多。造成這個問題的原因是RPC把用戶函數當成黑盒,對函數中發生的事情知之甚少。為了讓用戶函數能夠讓出和釋放 RPC 線程,需要向 RPC 系統提供更多的提示。

從 v1.6.0 開始,PyTorch 通過引入兩個新概念來解決這個問題:

torch.futures.Future 封裝了一個異步執行,同時也支持安裝回調函數。@rpc.functions.async_execution 裝飾器,它允許應用程序告訴被調用者,本目標函數將返回一個future,并且可以在執行過程中多次暫停和yield。

使用這兩個工具,應用程序代碼可以將用戶函數分解為多個較小的函數,將它們鏈接在一起作為Future 對象的回調方法,并返回包含最終結果的 Future給調用者。在被調用方,在獲取Future對象時,它也會安裝后續的 RPC 響應處理作為回調方法,這些回調會在最終結果準備好時被觸發。這樣,被調用者不再需要阻塞一個線程,只是等待最終返回值準備好就行。 簡單的例子請參考@rpc.functions.async_execution的API文檔 。

除了減少被調用者的空閑線程數量外,這些工具還使批處理 RPC 處理更容易、更快。本教程演示了如何使用@rpc.functions.async_execution 裝飾器構建分布式批量更新參數服務器和批量處理強化學習應用程序 。

注:我們不考慮強化學習的領域,那樣會影響我們的思路,牽扯精力

1.3 代碼

因為原文主要是強化學習代碼講解,而我們只關注普通分布式批量更新參數服務器,所以需要看原始代碼。

代碼位于 https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py。先全部摘錄如下:

代碼語言:JavaScript代碼運行次數:0運行復制

import osimport threadingfrom datetime import datetimeimport torchimport torch.distributed.rpc as rpcimport torch.multiprocessing as mpimport torch.nn as nnfrom torch import optimimport torchvisionbatch_size = 20image_w = 64image_h = 64num_classes = 30batch_update_size = 5num_batches = 6def timed_log(text):    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")class BatchUpdateParameterServer(object):    def __init__(self, batch_update_size=batch_update_size):        self.model = torchvision.models.resnet50(num_classes=num_classes)        self.lock = threading.Lock()        self.future_model = torch.futures.Future()        self.batch_update_size = batch_update_size        self.curr_update_size = 0        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)        for p in self.model.parameters():            p.grad = torch.zeros_like(p)    def get_model(self):        return self.model    @staticmethod    @rpc.functions.async_execution    def update_and_fetch_model(ps_rref, grads):        self = ps_rref.local_value()        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")        for p, g in zip(self.model.parameters(), grads):            p.grad += g        with self.lock:            self.curr_update_size += 1            fut = self.future_model            if self.curr_update_size >= self.batch_update_size:                for p in self.model.parameters():                    p.grad /= self.batch_update_size                self.curr_update_size = 0                self.optimizer.step()                self.optimizer.zero_grad()                fut.set_result(self.model)                timed_log("PS updated model")                self.future_model = torch.futures.Future()        return futclass Trainer(object):    def __init__(self, ps_rref):        self.ps_rref = ps_rref        self.loss_fn = nn.MSELoss()        self.one_hot_indices = torch.LongTensor(batch_size)                                     .random_(0, num_classes)                                     .view(batch_size, 1)    def get_next_batch(self):        for _ in range(num_batches):            inputs = torch.randn(batch_size, 3, image_w, image_h)            labels = torch.zeros(batch_size, num_classes)                         .scatter_(1, self.one_hot_indices, 1)            yield inputs.cuda(), labels.cuda()    def train(self):        name = rpc.get_worker_info().name        m = self.ps_rref.rpc_sync().get_model().cuda()        for inputs, labels in self.get_next_batch():            timed_log(f"{name} processing one batch")            self.loss_fn(m(inputs), labels).backward()            timed_log(f"{name} reporting grads")            m = rpc.rpc_sync(                self.ps_rref.owner(),                BatchUpdateParameterServer.update_and_fetch_model,                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),            ).cuda()            timed_log(f"{name} got updated model")def run_trainer(ps_rref):    trainer = Trainer(ps_rref)    trainer.train()def run_ps(trainers):    timed_log("Start training")    ps_rref = rpc.RRef(BatchUpdateParameterServer())    futs = []    for trainer in trainers:        futs.append(            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,))        )    torch.futures.wait_all(futs)    timed_log("Finish training")def run(rank, world_size):    os.environ['MASTER_ADDR'] = 'localhost'    os.environ['MASTER_PORT'] = '29500'    options=rpc.TensorPipeRpcBackendOptions(        num_worker_threads=16,        rpc_timeout=0  # infinite timeout     )    if rank != 0:        rpc.init_rpc(            f"trainer{rank}",            rank=rank,            world_size=world_size,            rpc_backend_options=options        )        # trainer passively waiting for ps to kick off training iterations    else:        rpc.init_rpc(            "ps",            rank=rank,            world_size=world_size,            rpc_backend_options=options        )        run_ps([f"trainer{r}" for r in range(1, world_size)])    # block until all rpcs finish    rpc.shutdown()if __name__=="__main__":    world_size = batch_update_size + 1    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)

0x02 啟動

我們首先看看如何啟動。

2.1 總體啟動

我們假設有一個master(rank 0),一個worker。Master 之上運行的是參數服務器,worker 之上是訓練代碼。

代碼語言:javascript代碼運行次數:0運行復制

def run(rank, world_size):    os.environ['MASTER_ADDR'] = 'localhost'    os.environ['MASTER_PORT'] = '29500'    options=rpc.TensorPipeRpcBackendOptions(        num_worker_threads=16,        rpc_timeout=0  # infinite timeout     )    if rank != 0:        rpc.init_rpc( # 訓練代碼            f"trainer{rank}",            rank=rank,            world_size=world_size,            rpc_backend_options=options        )        # trainer passively waiting for ps to kick off training iterations    else:        rpc.init_rpc( # 參數服務器            "ps",             rank=rank,            world_size=world_size,            rpc_backend_options=options        )        run_ps([f"trainer{r}" for r in range(1, world_size)])    # block until all rpcs finish    rpc.shutdown()if __name__=="__main__":    world_size = batch_update_size + 1    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)

邏輯如下圖:

代碼語言:javascript代碼運行次數:0運行復制

             torch.multiprocessing.spawn                        +                        |                        |           +------------+-------------------------------------------------           |                                                             |           |                                                             |           v                                                             v+----------+----------------------------------------------+ +------------+----------------+| "ps"                                           rank = 0 | | f"trainer{rank}"   rank = 1 ||                                                         | |                             ||                                                         | |                             ||                     rpc.init_rpc                        | |         rpc.init_rpc        ||                                                         | |                             ||                                                         | |                             ||  run_ps([f"trainer{r}" for r in range(1, world_size)])  | |                             ||                                                         | |                             ||                                                         | |                             |+---------------------------------------------------------+ +-----------------------------+

2.2 啟動參數服務器

run_ps 啟動了參數服務器和trainer。注意,這里在參數服務器之中啟動 trainer,即,master 不僅僅有一個參數服務器,還負責通過 rpc 來驅動trainer上的訓練循環

代碼語言:javascript代碼運行次數:0運行復制

def run_ps(trainers):    timed_log("Start training")    ps_rref = rpc.RRef(BatchUpdateParameterServer())    futs = []    for trainer in trainers: # trainer 是字符串,比如"trainer1"        futs.append(            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) # 運行run_trainer        )    torch.futures.wait_all(futs)    timed_log("Finish training")    def run_trainer(ps_rref):    trainer = Trainer(ps_rref)    trainer.train() # 調用 Trainer 的方法   

具體拓展如下:

這里沒有給出參數服務器和trainer的邏輯,我們會在后續分析之后陸續給出。trainer 也只給出了一個。

[源碼解析] PyTorch 分布式(16) — 使用異步執行實現批處理 RPC

0x03 參數服務器

上面圖中沒有給出具體參數服務器代碼,我們接下來就分析一下。

這里考慮具有一個參數服務器 (PS) 和多個trainer的同步訓練應用程序。在這個應用中,PS 持有參數并等待所有訓練器報告梯度。在每次迭代中,它等待直到從所有訓練器接收梯度,然后一次性更新所有參數。

下面的代碼顯示了 PS 類的實現。

PS初始化時候生成了常規SGB優化器,不是分布式優化器,而且優化器是在PS之上。update_and_fetch_model方法被 @rpc.functions.async_execution所裝飾,將由trainer調用。每次調用都會返回一個Future對象,該對象將被用來處理更新后的模型。大多數訓練器發起的調用只是累積梯度到 .grad成員變量 ,然后立即返回,并在 PS 上產生 RPC 線程。最后到達的訓練器將觸發優化器步驟并消耗所有先前上報的梯度。然后它使用更新后的模型來設置future_model,這是依靠通過Future對象來依次通知來自其他訓練者的先前請求,并將更新后的模型發送給所有訓練者。

具體代碼如下:

代碼語言:javascript代碼運行次數:0運行復制

batch_size = 20image_w = 64image_h = 64num_classes = 30batch_update_size = 5num_batches = 6def timed_log(text):    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")class BatchUpdateParameterServer(object):    def __init__(self, batch_update_size=batch_update_size):        self.model = torchvision.models.resnet50(num_classes=num_classes)        self.lock = threading.Lock()        self.future_model = torch.futures.Future()        self.batch_update_size = batch_update_size        self.curr_update_size = 0        # 重點:這里是常規SGB優化器,不是分布式優化器        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)        for p in self.model.parameters():            p.grad = torch.zeros_like(p)    def get_model(self):        return self.model    @staticmethod    @rpc.functions.async_execution # trainer會直接調用    def update_and_fetch_model(ps_rref, grads):        self = ps_rref.local_value()        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")        for p, g in zip(self.model.parameters(), grads): # 得到            p.grad += g # 累積梯度        with self.lock:            self.curr_update_size += 1            fut = self.future_model            if self.curr_update_size >= self.batch_update_size:                # 最后到達的訓練器將觸發優化器步驟并消耗所有先前上報的梯度。                for p in self.model.parameters():                    p.grad /= self.batch_update_size                self.curr_update_size = 0                self.optimizer.step() # 更新模型                self.optimizer.zero_grad()                fut.set_result(self.model) # 將更新后的模型發送給所有訓練者                timed_log("PS updated model")                self.future_model = torch.futures.Future() # 使用更新后的模型來設置future_model        return fut # 該對象將被用來處理更新后的模型

邏輯拓展如下,這里省略了參數服務器生成trainer的步驟:

手機如下:

[源碼解析] PyTorch 分布式(16) — 使用異步執行實現批處理 RPC

0x04 Trainer

對于訓練器,它們都使用來自 PS 的相同參數集進行初始化。在每次迭代中執行如下操作:

每個訓練器首先運行前向和后向傳播以在本地生成梯度。然后,每個訓練器使用 RPC 向 PS 報告其梯度,并通過同一 RPC 請求的返回值取回更新后的參數。

在訓練器的實現中,目標函數是否被標記 @rpc.functions.async_execution是沒有區別的。訓練器只需使用 rpc_sync 調用update_and_fetch_model,其將阻塞訓練器,直到返回更新的模型。

可以看到,參數服務器存儲模型,模型可以返回到trainer。

代碼語言:javascript代碼運行次數:0運行復制

class Trainer(object):    def __init__(self, ps_rref):        self.ps_rref = ps_rref        self.loss_fn = nn.MSELoss()        self.one_hot_indices = torch.LongTensor(batch_size)                                     .random_(0, num_classes)                                     .view(batch_size, 1)    def get_next_batch(self):        for _ in range(num_batches):            inputs = torch.randn(batch_size, 3, image_w, image_h)            labels = torch.zeros(batch_size, num_classes)                         .scatter_(1, self.one_hot_indices, 1)            yield inputs.cuda(), labels.cuda()    def train(self):        name = rpc.get_worker_info().name        # 從參數服務器獲取模型        m = self.ps_rref.rpc_sync().get_model().cuda()        for inputs, labels in self.get_next_batch():            timed_log(f"{name} processing one batch")            # 利用模型來前向傳播/反向傳播            self.loss_fn(m(inputs), labels).backward()            timed_log(f"{name} reporting grads")            # 調用參數服務器的函數來提交梯度            m = rpc.rpc_sync( # rpc_sync 操作完成之后,m就是最新模型了                self.ps_rref.owner(),                BatchUpdateParameterServer.update_and_fetch_model,                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),            ).cuda()            timed_log(f"{name} got updated model")

拓展邏輯如下:

參數服務器的run_trainer 方法會直接調用 trainer.train() 方法來執行一步step。train 方法之中,會調用 self.ps_rref.rpc_sync().get_model().cuda() 從參數服務器獲得模型,放到本地設備之上(圖上是雙向箭頭,表示這是一個get/return動作,需要把模型存儲在worker本地)。調用 self.loss_fn(m(inputs), labels).backward() 來進行前向傳播/反向傳播。調用參數服務器的 update_and_fetch_model 函數來提交梯度,這里使用了異步RPC。參數服務器的 update_and_fetch_model 之中,進行梯度累積,模型更新是通過PS之上常規SGD優化器完成,最后調用 fut.set_result(self.model) 來發布新模型給trainer。在trainer 之中,就是 m = rpc.rpc_sync(…) 這個賦值之后,m 是最新模型了。

[源碼解析] PyTorch 分布式(16) — 使用異步執行實現批處理 RPC

0x05 對比

前文結尾,我們對比參數服務器的經典實現 ps-lite 和 前兩篇實現的參數服務器。

ps-lite 是類似傳統服務器實現,有自己主動的業務循環,可以響應用戶的顯式請求,也有自己明確的邏輯,本地也有自己的KV存儲。PyTorch 前兩篇官方文檔(本系列前兩篇文章)之中,參數服務器則是另外一種思路: 參數服務器上沒有主動的循環,沒有KV存儲,沒有服務器邏輯,而是可以直接存儲業務模型,ps 會把業務模型需要優化的參數返回給trainer 之上的 DistributedOptimizer。業務驅動由trainer完成:train loop代碼在trainer 之中,DistributedOptimizer 在trainer 之中,DistributedOptimizer 負責進行分布式優化。本文又與上面不同,看起來更像是ps-lite,但是又糅合了RPC實現: ps進程會啟動trainer的訓練循環。每個迭代之中,trainer 會從參數服務器獲取最新模型,前向操作/后向傳播都在trainer 完成。trainer 會通過異步RPC把梯度提交給參數服務器。模型更新是通過PS之上常規SGD優化器完成。模型更新之后通過異步RPC把模型再次分發給trainer。

不得不說,官方這幾篇文章快把各種實現方式玩出花來了,大家可以依據自己業務特點來參考實現。

0xFF 參考

IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS

? 版權聲明
THE END
喜歡就支持一下吧
點贊10 分享