The distributed-computing tests (`test_dist.py`) verify the distributed-training utilities, including rank management, device detection, environment querying, and process-group initialization.
def test_get_torch_device_type(self):
    """Check that dist.get_torch_device_type() reports a known device name as a string."""
    # Backend names this test accepts; extend if dist supports other accelerators.
    known_types = ("cpu", "cuda", "xpu", "mps")
    kind = dist.get_torch_device_type()
    assert isinstance(kind, str)
    assert kind in known_types
def test_get_torch_device(self):
    """Check that dist.get_torch_device() returns a non-None object whose str() names the device type."""
    torch_device = dist.get_torch_device()
    assert torch_device is not None
    # The textual form of whatever object is returned should mention the
    # device type reported by get_torch_device_type().
    type_name = dist.get_torch_device_type()
    assert type_name in str(torch_device)
def test_query_environment(self):
    """Check that dist.query_environment() returns a non-empty dict."""
    info = dist.query_environment()
    assert isinstance(info, dict)
    # An empty dict would mean no environment details were reported at all.
    assert info
def test_get_dist_info(self):
    """Check that dist.get_dist_info() returns a dict with 'rank' and 'world_size' keys."""
    info = dist.get_dist_info()
    assert isinstance(info, dict)
    for required_key in ("rank", "world_size"):
        assert required_key in info