import java.util.*;

/**
 * Console program that reads a count {@code n} followed by {@code n} positive
 * integers (the "boxes") and prints the best total score obtainable by
 * repeatedly removing a run of equal adjacent values, where removing a run of
 * length {@code k} scores {@code k * k}.
 *
 * <p>Accepted input layouts: the values may follow the count on the same line
 * ({@code "3 1 1 1"}) or be placed on the next non-blank line. Values may be
 * separated by any amount of whitespace. Anything else prints
 * {@code "Invalid input"}.
 */
public class Main {

    /** Message printed for every rejected input. */
    private static final String INVALID_INPUT = "Invalid input";

    /**
     * Reads the input from standard input, validates it and prints either the
     * maximum score or {@code "Invalid input"}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        // Guard against empty or non-numeric input instead of letting nextInt() throw.
        if (!sc.hasNextInt()) {
            System.out.println(INVALID_INPUT);
            return;
        }
        int n = sc.nextInt();
        if (n <= 0) {
            System.out.println(INVALID_INPUT);
            return;
        }

        // nextInt() stops right after the number, so when the values are on
        // their own line the first nextLine() only yields the (empty) rest of
        // the count line. Keep reading until a non-blank line is found.
        String line = "";
        while (line.isEmpty() && sc.hasNextLine()) {
            line = sc.nextLine().trim();
        }
        if (line.isEmpty()) {
            System.out.println(INVALID_INPUT);
            return;
        }

        // Split on runs of whitespace so tabs / repeated spaces do not create empty tokens.
        String[] parts = line.split("\\s+");
        if (parts.length != n) {
            System.out.println(INVALID_INPUT);
            return;
        }

        int[] boxes = new int[n];
        for (int i = 0; i < n; i++) {
            try {
                boxes[i] = Integer.parseInt(parts[i]);
            } catch (NumberFormatException e) {
                // A non-integer token is invalid input, not a crash.
                System.out.println(INVALID_INPUT);
                return;
            }
            if (boxes[i] <= 0) {
                System.out.println(INVALID_INPUT);
                return;
            }
        }

        Map<String, Integer> memo = new HashMap<>();
        int result = maxPoints(boxes, memo);
        System.out.println(result);
    }

    /**
     * Computes the maximum score for the given row of boxes.
     *
     * @param boxes the box values, in order
     * @param memo  cache of already-solved sequences, keyed by the sequence's
     *              {@code List.toString()} form; entries are added during the call
     * @return the best total score; {@code 0} for an empty array
     */
    static int maxPoints(int[] boxes, Map<String, Integer> memo) {
        return dp(new ArrayList<>(Arrays.stream(boxes).boxed().toList()), memo);
    }

    /**
     * Recursive search with memoisation: for every maximal run of equal
     * adjacent values in {@code list}, removes that run, scores
     * {@code length * length} and recurses on the remaining sequence.
     *
     * @param list the remaining boxes; not modified by this method
     * @param memo cache keyed by {@code list.toString()}
     * @return the best score obtainable from {@code list}
     */
    static int dp(List<Integer> list, Map<String, Integer> memo) {
        if (list.isEmpty()) {
            return 0;
        }

        String key = list.toString();
        Integer cached = memo.get(key);
        if (cached != null) {
            return cached;
        }

        int max = 0;
        for (int i = 0; i < list.size(); ) {
            // Extend j to one past the end of the run starting at i.
            int j = i;
            while (j < list.size() && list.get(j).equals(list.get(i))) {
                j++;
            }

            int count = j - i;
            int score = count * count;

            // Work on a copy so sibling branches still see the original sequence.
            List<Integer> newList = new ArrayList<>(list);
            newList.subList(i, j).clear();

            max = Math.max(max, score + dp(newList, memo));
            i = j; // jump to the start of the next run
        }

        memo.put(key, max);
        return max;
    }
}
