- Saves the model states from the n best epochs, then averages their weights to produce the final model.
- implementation: https://github.com/btbpanda/CAFA5-protein-function-prediction-2nd-place/blob/main/protnn/swa.py#L8
- you can do either a simple average or a weighted average
- the weighted average is based on the score you pass in for each checkpoint:
  `w = score / sum(self.scores[:n]) if weighted else 1 / n`
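- `self.scores` is appended to on every call to `add_checkpoint(self, model, score=1)`, and each checkpoint's weight is its score divided by `sum(self.scores[:n])`; if every checkpoint keeps the default `score=1`, the weighted average reduces to the simple `1 / n` average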
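The scheme above can be sketched in plain Python. This is a minimal sketch, not the linked implementation: the class name `CheckpointAverager`, the `average()` method, and passing a plain dict of parameters (standing in for something like a PyTorch `state_dict()`) are assumptions for illustration. It keeps the n best-scoring snapshots and applies the same weight rule as the formula in the notes.

```python
import copy
from typing import Dict, List


class CheckpointAverager:
    """Hypothetical sketch: keep the n best checkpoints and average them."""

    def __init__(self, n: int = 3):
        self.n = n
        self.scores: List[float] = []
        self.states: List[Dict[str, float]] = []

    def add_checkpoint(self, state: Dict[str, float], score: float = 1.0) -> None:
        # Snapshot the weights so later training steps don't mutate them.
        self.scores.append(score)
        self.states.append(copy.deepcopy(state))
        # Keep both lists sorted best-first and drop everything past n.
        order = sorted(
            range(len(self.scores)), key=lambda i: self.scores[i], reverse=True
        )[: self.n]
        self.scores = [self.scores[i] for i in order]
        self.states = [self.states[i] for i in order]

    def average(self, weighted: bool = False) -> Dict[str, float]:
        k = len(self.scores)  # may be < n before n checkpoints have been saved
        total = sum(self.scores[:k])
        averaged: Dict[str, float] = {}
        for score, state in zip(self.scores, self.states):
            # Same rule as in the notes, with k in place of n.
            w = score / total if weighted else 1 / k
            for name, value in state.items():
                averaged[name] = averaged.get(name, 0) + w * value
        return averaged


avg = CheckpointAverager(n=2)
avg.add_checkpoint({"w": 1.0}, score=0.5)
avg.add_checkpoint({"w": 3.0}, score=0.7)
avg.add_checkpoint({"w": 5.0}, score=0.6)  # the 0.5 checkpoint is dropped
print(avg.average())  # {'w': 4.0}
print(avg.average(weighted=True))  # leans toward the 0.7-scored weights
```

Keeping the lists sorted best-first is one way to make `self.scores[:n]` mean "the n best scores". The sketch divides by the number of stored checkpoints `k` rather than `n`, so the simple average stays correct even before n checkpoints have been saved.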