Skip to content

Map 基本实现及常见问题

类图:

  • HashMap 无需的键值对,常用
  • TreeMap 有序的键值对
  • Hashtable 线程安全版的 HashMap,在方法上加 synchronized 关键字
  • WeakHashMap 弱键版的 HashMap

HashMap

常用 API

java
@Test
public void testHashMap() {
    HashMap<String, String> map = new HashMap<>();

    map.put("a", "0");
    map.put("b", "1");
    map.put("c", "2");

    String a = map.get("a");
    System.out.println(a);

    Set<String> keys = map.keySet();
    System.out.println(keys);

    Collection<String> values = map.values();
    System.out.println(values);

    Set<Map.Entry<String, String>> entries = map.entrySet();
    System.out.println(entries);
}

下面从这些 API 入手读懂 HashMap 的核心源码

数据结构与构造函数

HashMap 以 拉链法 存储数据:内置一个 Node 数组,put 键值对时,计算 key 的哈希值,并已数组长度取模,得到该键值对的索引,创建节点放入数组对应位置,当出现哈希碰撞(key 的哈希取模后相等,但 key 不相等)时,将该节点存储为已有节点的下一个节点,即组成链表,如果链表长度超过 8 则将链表转为红黑树。

成员变量

java
// Node 数组,即存放节点的桶
transient Node<K,V>[] table;
// 节点视图,主要用于遍历
transient Set<Map.Entry<K,V>> entrySet;
// 实际存储键值对的数量
transient int size;
// 改变内部结构时 +1,用于 fail-fast 机制,检测并发错误
transient int modCount;
// 容量阈值,容量超过该值时扩容,其实就是 map 的容量
int threshold;
// 负载因子,决定空间利用率
final float loadFactor;

构造函数

java
// 无参构造,成员变量都使用默认值
public HashMap() {
    this.loadFactor = DEFAULT_LOAD_FACTOR;
}
// 提供初始容量
public HashMap(int initialCapacity) {
    this(initialCapacity, DEFAULT_LOAD_FACTOR);
}
// 提供初始容量和负载因子
public HashMap(int initialCapacity, float loadFactor) {
    if (initialCapacity < 0)
        throw new IllegalArgumentException("Illegal initial capacity: " +
                                            initialCapacity);
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    if (loadFactor <= 0 || Float.isNaN(loadFactor))
        throw new IllegalArgumentException("Illegal load factor: " +
                                            loadFactor);
    this.loadFactor = loadFactor;
    // 计算出不小于 initialCapacity 的 2 的次方作为阈值
    this.threshold = tableSizeFor(initialCapacity);
}
static final int tableSizeFor(int cap) {
    int n = cap - 1;
    n |= n >>> 1;
    n |= n >>> 2;
    n |= n >>> 4;
    n |= n >>> 8;
    n |= n >>> 16;
    return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1;
}

// 提供现有map的数据
public HashMap(Map<? extends K, ? extends V> m) {
    this.loadFactor = DEFAULT_LOAD_FACTOR;
    putMapEntries(m, false);
}
final void putMapEntries(Map<? extends K, ? extends V> m, boolean evict) {
    int s = m.size();
    if (s > 0) {
        if (table == null) { // pre-size
            // 初始化 table
            // 根据提供的 map 的元素数量,配合负载因子计算所需容量
            float ft = ((float)s / loadFactor) + 1.0F;
            int t = ((ft < (float)MAXIMUM_CAPACITY) ?
                        (int)ft : MAXIMUM_CAPACITY);
            if (t > threshold)
                threshold = tableSizeFor(t);
        }
        else if (s > threshold)
            // 超过阈值,扩容
            resize();
        for (Map.Entry<? extends K, ? extends V> e : m.entrySet()) {
            K key = e.getKey();
            V value = e.getValue();
            putVal(hash(key), key, value, false, evict);
        }
    }
}

map.put(key, value) 添加数据源码

java
// put 入口方法
public V put(K key, V value) {
    return putVal(hash(key), key, value, false, true);
}
// 计算 hash 值
static final int hash(Object key) {
    // 这里为什么用 key 的 hashCode 与 右移 16 位后的 hashCode 做异或运算
    // 1. 不直接使用 key 的 hashCode 是为了使 hash 值更复杂,分布更均匀
    // 2. 为了在 tab[(n - 1) & hash] 给 hash 与 数组长度做取模运算时,一般数组长度较小,hash 高位可能没参与运算,为了让 hash 的高位与低位都参与运算,所以做了右移和异或运算
    int h;
    return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
}
// put 操作
final V putVal(int hash, K key, V value, boolean onlyIfAbsent, boolean evict) {
    Node<K,V>[] tab; Node<K,V> p; int n, i;
    if ((tab = table) == null || (n = tab.length) == 0)
        // table 为空时需要初始化,见下面 resize() 方法
        n = (tab = resize()).length;
    if ((p = tab[i = (n - 1) & hash]) == null)
        // 桶位置没有节点则创建节点放在该位置
        tab[i] = newNode(hash, key, value, null);
    else {
        // 该桶位置存在节点
        Node<K,V> e; K k;
        if (p.hash == hash &&
            ((k = p.key) == key || (key != null && key.equals(k))))
            // 新 key 与首节点 key 相同
            e = p;
        else if (p instanceof TreeNode)
            // 桶里是棵树
            e = ((TreeNode<K,V>)p).putTreeVal(this, tab, hash, key, value);
        else {
            // 链表
            for (int binCount = 0; ; ++binCount) {
                if ((e = p.next) == null) {
                    // 循环走到了最后一个节点,就在最后添加节点
                    p.next = newNode(hash, key, value, null);
                    if (binCount >= TREEIFY_THRESHOLD - 1) // -1 for 1st
                        // 如果满足条件就转成树 TODO
                        treeifyBin(tab, hash);
                    break;
                }
                if (e.hash == hash &&
                    ((k = e.key) == key || (key != null && key.equals(k))))
                    // key 与链表上某个节点的 key 相同
                    break;
                p = e;
            }
        }
        if (e != null) { // existing mapping for key
            V oldValue = e.value;
            if (!onlyIfAbsent || oldValue == null)
                // 满足条件(key存在)则更新 value
                e.value = value;
            afterNodeAccess(e);
            return oldValue;
        }
    }
    ++modCount;
    if (++size > threshold)
        // 超过阈值则扩容
        resize();
    afterNodeInsertion(evict);
    return null;
}
// 核心方法,数组初始化/扩容
final Node<K,V>[] resize() {
    Node<K,V>[] oldTab = table;
    int oldCap = (oldTab == null) ? 0 : oldTab.length;
    int oldThr = threshold;
    int newCap, newThr = 0;
    // 计算新容量
    if (oldCap > 0) {
        if (oldCap >= MAXIMUM_CAPACITY) {
            // 数组已经最大,没法扩了
            threshold = Integer.MAX_VALUE;
            return oldTab;
        }
        else if ((newCap = oldCap << 1) < MAXIMUM_CAPACITY &&
                    oldCap >= DEFAULT_INITIAL_CAPACITY)
            // 阈值翻倍
            newThr = oldThr << 1; // double threshold
    }
    else if (oldThr > 0) // initial capacity was placed in threshold
        // 旧的数组容量为0,阈值大于0,则数组新容量更新为老阈值大小,这个场景是自定义了初始容量,最后数组容量跟阈值相等了,负载因子在扩容的时候才起作用
        newCap = oldThr;
    else {               // zero initial threshold signifies using defaults
        // 旧数组容量为0且阈值也为0,说明还没初始化,并且没有自定义容量和负载因子
        // 则数组容量设置为默认初始容量,新阈值更新为 默认初始容量 * 负载因子
        newCap = DEFAULT_INITIAL_CAPACITY;
        newThr = (int)(DEFAULT_LOAD_FACTOR * DEFAULT_INITIAL_CAPACITY);
    }
    if (newThr == 0) {
        // 如果阈值还没初始化,则更新为 新容量 * 负载因子
        float ft = (float)newCap * loadFactor;
        newThr = (newCap < MAXIMUM_CAPACITY && ft < (float)MAXIMUM_CAPACITY ?
                    (int)ft : Integer.MAX_VALUE);
    }
    threshold = newThr;
    @SuppressWarnings({"rawtypes","unchecked"})
    Node<K,V>[] newTab = (Node<K,V>[])new Node[newCap];
    table = newTab;
    if (oldTab != null) {
        // 老数组迁移到新数组
        for (int j = 0; j < oldCap; ++j) {
            Node<K,V> e;
            if ((e = oldTab[j]) != null) {
                // 老数组中的桶赋值为 null,可能是为了帮助 GC
                oldTab[j] = null;
                if (e.next == null)
                    // 如果该桶是单节点,直接赋值到新数组的对应位置,索引为 hash 取模
                    newTab[e.hash & (newCap - 1)] = e;
                else if (e instanceof TreeNode)
                    // 如果该桶点是一棵树
                    ((TreeNode<K,V>)e).split(this, newTab, j, oldCap);
                else { // preserve order
                    // 该桶是多节点链表,迁移
                    Node<K,V> loHead = null, loTail = null;
                    Node<K,V> hiHead = null, hiTail = null;
                    Node<K,V> next;
                    do {
                        next = e.next;
                        // 同一个桶里的数据最多会被拆分为两个桶(重点)
                        if ((e.hash & oldCap) == 0) {
                            if (loTail == null)
                                loHead = e;
                            else
                                loTail.next = e;
                            loTail = e;
                        }
                        else {
                            if (hiTail == null)
                                hiHead = e;
                            else
                                hiTail.next = e;
                            hiTail = e;
                        }
                    } while ((e = next) != null);
                    // 放入新数组
                    if (loTail != null) {
                        loTail.next = null;
                        newTab[j] = loHead;
                    }
                    if (hiTail != null) {
                        hiTail.next = null;
                        newTab[j + oldCap] = hiHead;
                    }
                }
            }
        }
    }
    return newTab;
}

map.get(key) 查找数据源码

java
// 查找
public V get(Object key) {
    Node<K,V> e;
    // 具体查找 node
    return (e = getNode(hash(key), key)) == null ? null : e.value;
}
// 查找 node
final Node<K,V> getNode(int hash, Object key) {
    Node<K,V>[] tab; Node<K,V> first, e; int n; K k;
    // 首节点有值
    if ((tab = table) != null && (n = tab.length) > 0 &&
        (first = tab[(n - 1) & hash]) != null) {
        if (first.hash == hash && // always check first node
            ((k = first.key) == key || (key != null && key.equals(k))))
            // 检查首节点
            return first;
        if ((e = first.next) != null) {
            if (first instanceof TreeNode)
                // 桶里是树
                return ((TreeNode<K,V>)first).getTreeNode(hash, key);
            do {
                // 遍历链表检查节点的 key
                if (e.hash == hash &&
                    ((k = e.key) == key || (key != null && key.equals(k))))
                    return e;
            } while ((e = e.next) != null);
        }
    }
    return null;
}

map.remove(key) 移除数据源码

java
// 移除
public V remove(Object key) {
    Node<K,V> e;
    return (e = removeNode(hash(key), key, null, false, true)) == null ?
        null : e.value;
}
// 查找并移除节点
final Node<K,V> removeNode(int hash, Object key, Object value,
                            boolean matchValue, boolean movable) {
    Node<K,V>[] tab; Node<K,V> p; int n, index;
    if ((tab = table) != null && (n = tab.length) > 0 &&
        (p = tab[index = (n - 1) & hash]) != null) {
        Node<K,V> node = null, e; K k; V v;
        if (p.hash == hash &&
            ((k = p.key) == key || (key != null && key.equals(k))))
            // 首节点就是了
            node = p;
        else if ((e = p.next) != null) {
            if (p instanceof TreeNode)
                // 从树里查找
                node = ((TreeNode<K,V>)p).getTreeNode(hash, key);
            else {
                // 遍历链表查找
                do {
                    if (e.hash == hash &&
                        ((k = e.key) == key ||
                            (key != null && key.equals(k)))) {
                        node = e;
                        break;
                    }
                    p = e;
                } while ((e = e.next) != null);
            }
        }
        if (node != null && (!matchValue || (v = node.value) == value ||
                                (value != null && value.equals(v)))) {
            if (node instanceof TreeNode)
                // 移除树上的节点
                ((TreeNode<K,V>)node).removeTreeNode(this, tab, movable);
            else if (node == p)
                // 是首节点,移除首节点
                tab[index] = node.next;
            else
                // 移除移除非首节点
                p.next = node.next;
            ++modCount;
            --size;
            afterNodeRemoval(node);
            return node;
        }
    }
    return null;
}