class LogCoshLoss(nn.Module):
    """Mean log-cosh loss of the scaled difference between two tensors.

    Computes ``mean(log(cosh((y_t - y_prime_t) / scale)))``.

    The value is evaluated with the identity
    ``log(cosh(x)) = |x| + log1p(exp(-2|x|)) - ln(2)``, which never
    exponentiates a positive number.  The direct ``log(cosh(x))`` form
    overflows to ``inf`` as soon as ``cosh(x)`` exceeds the floating-point
    range (around ``|x| > 89`` in float32), regardless of the scaling.

    Args:
        scale: Positive divisor applied to the difference before the
            log-cosh.  Defaults to 3.0, which reproduces the historical
            behaviour of this loss.

    Raises:
        ValueError: If ``scale`` is not strictly positive.
    """

    _LN2 = 0.6931471805599453  # ln(2)

    def __init__(self, scale: float = 3.0):
        super().__init__()
        if not scale > 0:
            raise ValueError(f"scale must be positive, got {scale!r}")
        self.scale = scale

    def forward(self, y_prime_t: torch.Tensor, y_t: torch.Tensor) -> torch.Tensor:
        """Return the scalar mean log-cosh of ``(y_t - y_prime_t) / scale``.

        log-cosh is an even function, so swapping the two arguments does not
        change the result.  The inputs are combined with ``-`` and therefore
        must be broadcastable against each other.
        """
        abs_err = torch.abs((y_t - y_prime_t) / self.scale)
        # exp(-2|x|) lies in [0, 1], so nothing here can overflow; for large
        # |x| the expression tends to |x| - ln(2), matching log(cosh(x)).
        # No epsilon is needed: cosh(x) >= 1, so its log is always defined.
        log_cosh = abs_err + torch.log1p(torch.exp(-2.0 * abs_err)) - self._LN2
        return torch.mean(log_cosh)