Memory usage when converting a fine-tuned GPT-J-6B checkpoint to Hugging Face format

Fine-tuning GPT-J on TPUs by following this tutorial has worked well: https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md

Why would the step that converts the checkpoint to Hugging Face format with to_hf_weights.py run into a memory problem at 256MB, even after the checkpoint has been slimmed?

The issue I filed is here: https://github.com/kingoflolz/mesh-transformer-jax/issues/209



Solution 1:[1]

Resolved by running this step on a standard machine (not a TPU) with plenty of RAM.
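As a rough sanity check on why a machine with plenty of RAM helps, the sketch below estimates how much memory is needed just to hold the weights of a model of GPT-J-6B's size. It assumes the nominal 6 billion parameters. The `overhead` factor of 2.0 (source and converted tensors held in memory at the same time) is a hypothetical assumption, not confirmed behavior of to_hf_weights.py, and the helper names are illustrative only.

```python
import os

# Bytes per parameter for common dtypes.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}


def estimate_ram_gib(n_params: float, dtype: str = "fp32", overhead: float = 1.0) -> float:
    """Rough RAM (GiB) needed to hold n_params weights of the given dtype.

    `overhead` is a hypothetical multiplier for extra copies held during
    conversion (e.g. 2.0 if source and converted tensors coexist).
    """
    return n_params * BYTES_PER_PARAM[dtype] * overhead / 2**30


def installed_ram_gib() -> float:
    """Total physical RAM on this machine in GiB (POSIX only, via os.sysconf)."""
    return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 2**30


if __name__ == "__main__":
    n_params = 6e9  # nominal parameter count implied by the "6B" in GPT-J-6B
    for dtype in ("fp32", "fp16"):
        # Assumed 2x overhead for the conversion step (hypothetical).
        need = estimate_ram_gib(n_params, dtype, overhead=2.0)
        print(f"{dtype}: ~{need:.1f} GiB needed, {installed_ram_gib():.1f} GiB installed")
```

Weights alone come to roughly 22 GiB in fp32 and 11 GiB in fp16, far beyond a 256MB budget, which is consistent with the conversion succeeding only on a machine with a lot of memory.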

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Jonathan Hendler