import { Equal, mapEqual } from "./equal";
import { iterableMap } from "./iterator";

/**
 * Returns the value stored under `key`, creating and storing it first when
 * the key is absent.
 *
 * Presence is decided with `Map.has`, so a key whose stored value is
 * `undefined` counts as present and `valueFn` is not invoked for it.
 *
 * @param map - Map to read from and, on a miss, write to.
 * @param key - Key to look up.
 * @param valueFn - Factory called only when `key` is absent.
 * @returns The existing value, or the freshly created one.
 */
export function getOrSet<K, V>(map: Map<K, V>, key: K, valueFn: () => V): V {
  if (!map.has(key)) {
    const created = valueFn();
    map.set(key, created);
    return created;
  }
  return map.get(key)!;
}

/**
 * Folds the elements of `iterable` into one accumulator per key.
 *
 * @param iterable - Elements to aggregate, consumed in order.
 * @param key - Derives the grouping key of an element.
 * @param aggregate - Combines the current accumulator of a key with an
 *   element and returns the new accumulator.
 * @param initial - Produces the starting accumulator; called each time a key
 *   is encountered that is not yet in the result.
 * @returns A map from key to final accumulator, in first-seen key order.
 */
export function aggregateBy<K, V, A>(
  iterable: Iterable<V>,
  key: (value: V) => K,
  aggregate: (acc: A, value: V) => A,
  initial: () => A,
): Map<K, A> {
  const totals = new Map<K, A>();
  for (const item of iterable) {
    const bucketKey = key(item);
    // `has` (not a check for undefined) so an accumulator that is itself
    // `undefined` is not re-initialised.
    const previous = totals.has(bucketKey)
      ? totals.get(bucketKey)!
      : initial();
    totals.set(bucketKey, aggregate(previous, item));
  }
  return totals;
}

/**
 * Merges two maps into a new one without modifying either input.
 *
 * Every entry of `mapA` is inserted first, followed by every entry of
 * `mapB`; a key present in both therefore ends up with the value from
 * `mapB` while keeping the position it had in `mapA`.
 *
 * @param mapA - Base entries.
 * @param mapB - Entries that take precedence on key collisions.
 * @returns A new map holding the combined entries.
 */
export function overwriteCombineMaps<K, V>(
  mapA: Map<K, V>,
  mapB: Map<K, V>,
): Map<K, V> {
  return new Map<K, V>([...mapA, ...mapB]);
}

/**
 * Partitions the elements of `iterable` into arrays keyed by `key(value)`.
 *
 * @param iterable - Elements to group, consumed in order.
 * @param key - Derives the grouping key of an element.
 * @returns A map from key to the elements that produced it. Keys appear in
 *   first-seen order and each array keeps the original element order.
 */
export function groupBy<K, V>(
  iterable: Iterable<V>,
  key: (value: V) => K,
): Map<K, V[]> {
  const groups = new Map<K, V[]>();
  for (const item of iterable) {
    const groupKey = key(item);
    const bucket = groups.get(groupKey);
    if (bucket === undefined) {
      // First element for this key: start its bucket with the element.
      groups.set(groupKey, [item]);
    } else {
      bucket.push(item);
    }
  }
  return groups;
}

/**
 * Builds a map by converting every element of `iterable` into a
 * `[key, value]` pair.
 *
 * Despite the name, values are not collected per key: when several elements
 * yield the same key, the pair produced last replaces the earlier ones.
 *
 * @param iterable - Elements to convert, consumed in order.
 * @param entry - Produces the `[key, value]` pair for an element.
 * @returns A map containing the produced pairs.
 */
export function groupMap<T, K, V>(
  iterable: Iterable<T>,
  entry: (value: T) => [K, V],
): Map<K, V> {
  const mapped = new Map<K, V>();
  for (const item of iterable) {
    mapped.set(...entry(item));
  }
  return mapped;
}

/**
 * Creates a new map with the same keys as `map` and values produced by `fn`.
 *
 * The entries are transformed through the imported `iterableMap` helper
 * (presumably it applies the callback to each element of the iterable —
 * confirm against ./iterator) and the result is fed to the `Map`
 * constructor. The input map is not modified by this function.
 *
 * @param map - Source map.
 * @param fn - Receives each value and its key and returns the new value.
 * @returns A new map from the original keys to the transformed values.
 */
export function mapMapValues<K, A, B>(
  map: Map<K, A>,
  fn: (value: A, key: K) => B,
): Map<K, B> {
  const transformedEntries = iterableMap(
    map,
    ([entryKey, entryValue]): [K, B] => [entryKey, fn(entryValue, entryKey)],
  );
  return new Map(transformedEntries);
}

/**
 * Wraps `fn` so that its result is computed once per distinct key and then
 * served from an internal `Map` on subsequent calls.
 *
 * Cached results are never evicted. A key whose cached result is
 * `undefined` still counts as cached, because presence is checked with
 * `Map.has`.
 *
 * @param fn - Function whose results should be cached by argument.
 * @returns A function with the same signature that consults the cache first.
 */
export function memoize<K, V>(fn: (key: K) => V): (key: K) => V {
  const results = new Map<K, V>();
  return (key) => {
    if (results.has(key)) {
      return results.get(key)!;
    }
    const computed = fn(key);
    results.set(key, computed);
    return computed;
  };
}

/**
 * Like a plain memoizer, but hands `fn` a `remove` callback that deletes the
 * cache entry for the key being computed, so the next call with that key
 * runs `fn` again.
 *
 * The result is stored only after `fn` returns, so calling `remove` while
 * `fn` is still running has no lasting effect. `remove` deletes by key: it
 * drops whatever entry is cached for that key at the time it is called.
 *
 * @param fn - Computes the value for a key; receives the eviction callback
 *   as its second argument.
 * @returns A function that returns the cached value for a key, computing it
 *   when no entry is present.
 */
export function memoizeRemoveable<K, V>(
  fn: (key: K, remove: () => void) => V,
): (key: K) => V {
  const results = new Map<K, V>();
  return (key) => {
    if (results.has(key)) {
      return results.get(key)!;
    }
    const evict = () => results.delete(key);
    const computed = fn(key, evict);
    results.set(key, computed);
    return computed;
  };
}

/**
 * A map whose contents never change after construction.
 *
 * `set` and `delete` leave the receiver untouched and return a new instance
 * backed by a copy of the underlying `Map`. Instances are created through
 * the static factories (`empty`, `of`, `fromEntries`, `fromObject`); the
 * constructor is private.
 */
export class ImmutableMap<K, V> {
  private constructor(private readonly map: Map<K, V>) {}

  /** Number of entries in the map. */
  get size(): number {
    return this.map.size;
  }

  /**
   * Returns a new map without `key`. The receiver is not modified.
   *
   * @param key - Key to remove.
   */
  delete(key: K): ImmutableMap<K, V> {
    const newMap = new Map(this.map);
    newMap.delete(key);
    return new ImmutableMap(newMap);
  }

  /** Iterates the `[key, value]` pairs in insertion order. */
  entries(): Iterable<[K, V]> {
    return this.map.entries();
  }

  /**
   * Looks up the value stored under `key`.
   *
   * @returns The value, or `undefined` when the key is absent.
   */
  get(key: K): V | undefined {
    return this.map.get(key);
  }

  /** Iterates the keys in insertion order. */
  keys(): Iterable<K> {
    return this.map.keys();
  }

  /**
   * Returns a new map in which `key` is associated with `value`. The
   * receiver is not modified.
   *
   * @param key - Key to add or replace.
   * @param value - Value to store under `key`.
   */
  set(key: K, value: V): ImmutableMap<K, V> {
    const newMap = new Map(this.map);
    newMap.set(key, value);
    return new ImmutableMap(newMap);
  }

  /**
   * Invokes `callbackfn` once per entry, in insertion order.
   *
   * @param callbackfn - Receives the value, the key and a `Map` holding the
   *   entries. That `Map` is a snapshot made for this call; mutating it does
   *   not affect this `ImmutableMap`.
   */
  forEach(callbackfn: (value: V, key: K, map: Map<K, V>) => void): void {
    // `Map.prototype.forEach` passes the iterated map to the callback.
    // Iterating a copy keeps the private backing map out of callers' hands,
    // which would otherwise let them mutate an "immutable" instance.
    new Map(this.map).forEach(callbackfn);
  }

  /** Returns a mutable copy of the entries; changing it does not affect this map. */
  toMap(): Map<K, V> {
    return new Map(this.map);
  }

  /** Iterates the values in insertion order. */
  values(): Iterable<V> {
    return this.map.values();
  }

  /** Iterates the `[key, value]` pairs, enabling `for...of` and spread. */
  [Symbol.iterator](): IterableIterator<[K, V]> {
    return this.map[Symbol.iterator]();
  }

  // Single shared instance handed out by `empty()`.
  private static readonly EMPTY = new this(new Map());

  /** Returns the shared empty map. */
  static empty<K, V>(): ImmutableMap<K, V> {
    return this.EMPTY;
  }

  /**
   * Creates a map from an iterable of `[key, value]` pairs. The pairs are
   * copied into a new backing `Map`.
   */
  static fromEntries<K, V>(entries: Iterable<[K, V]>): ImmutableMap<K, V> {
    return new ImmutableMap(new Map(entries));
  }

  /**
   * Creates a map from the entries reported by `Object.entries(object)`.
   */
  static fromObject<K extends string, V>(
    object: Record<K, V>,
  ): ImmutableMap<K, V> {
    // `Object.entries` is typed with `string` keys; narrow back to `K`.
    return this.fromEntries(Object.entries(object) as [K, V][]);
  }

  /** Creates a map from the given `[key, value]` pairs. */
  static of<K, V>(...entries: [K, V][]): ImmutableMap<K, V> {
    return new ImmutableMap(new Map(entries));
  }
}

/**
 * Builds an equality function for `ImmutableMap`s from an equality function
 * for their values.
 *
 * The comparison itself is delegated to the comparator returned by the
 * imported `mapEqual(valueEquals)`, which is created once; each call copies
 * both operands with `toMap()` and passes the copies to it.
 *
 * @param valueEquals - Equality function for values.
 * @returns An equality function over `ImmutableMap<K, V>`.
 */
export function immutableMapEqual<K, V>(
  valueEquals: Equal<V>,
): Equal<ImmutableMap<K, V>> {
  const compareMaps = mapEqual(valueEquals);
  return (left, right) => {
    const leftEntries = left.toMap();
    const rightEntries = right.toMap();
    return compareMaps(leftEntries, rightEntries);
  };
}
