I am running the run_mlm_flax scripts on TPU v4. Will the parameter per_device_batch_size automatically scale when running on pods, or do the partitions need to be defined? And how? To clarify: let's say I am running a script with per_device_batch_size=100 on 4 TPU VMs (a TPU v4-32). How big will the actual batch size be?
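For reference, here is a minimal sketch of the arithmetic I am assuming, namely that the global batch size is the per-device value multiplied by `jax.device_count()`, which counts devices across all hosts in the pod slice (the variable names below are just illustrative, not taken from the script):

```python
import jax

# Assumed scaling: global batch = per-device batch * total devices in the pod slice.
per_device_batch_size = 100            # hypothetical value from my question
total_devices = jax.device_count()     # devices across *all* hosts, not just the local one
global_batch_size = per_device_batch_size * total_devices
print(f"{total_devices} devices -> global batch size {global_batch_size}")
```

Is that the correct mental model, or do I need to define partitions/sharding explicitly for the pod case?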