001    /**
002     * Copyright (c) 2000-2012 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    
020    import java.io.Closeable;
021    import java.io.IOException;
022    
023    import java.util.HashMap;
024    import java.util.Map;
025    import java.util.concurrent.atomic.AtomicInteger;
026    
027    /**
028     * @author Shuyang Zhou
029     */
030    public class CentralizedThreadLocal<T> extends ThreadLocal<T> {
031    
032            public static void clearLongLivedThreadLocals() {
033                    _longLivedThreadLocals.remove();
034            }
035    
036            public static void clearShortLivedThreadLocals() {
037                    _shortLivedThreadLocals.remove();
038            }
039    
040            public static Map<CentralizedThreadLocal<?>, Object>
041                    getLongLivedThreadLocals() {
042    
043                    return _toMap(_longLivedThreadLocals.get());
044            }
045    
046            public static Map<CentralizedThreadLocal<?>, Object>
047                    getShortLivedThreadLocals() {
048    
049                    return _toMap(_shortLivedThreadLocals.get());
050            }
051    
052            public static void setThreadLocals(
053                    Map<CentralizedThreadLocal<?>, Object> longLivedThreadLocals,
054                    Map<CentralizedThreadLocal<?>, Object> shortLivedThreadLocals) {
055    
056                    ThreadLocalMap threadLocalMap = _longLivedThreadLocals.get();
057    
058                    for (Map.Entry<CentralizedThreadLocal<?>, Object> entry :
059                                    longLivedThreadLocals.entrySet()) {
060    
061                            threadLocalMap.putEntry(entry.getKey(), entry.getValue());
062                    }
063    
064                    threadLocalMap = _shortLivedThreadLocals.get();
065    
066                    for (Map.Entry<CentralizedThreadLocal<?>, Object> entry :
067                                    shortLivedThreadLocals.entrySet()) {
068    
069                            threadLocalMap.putEntry(entry.getKey(), entry.getValue());
070                    }
071            }
072    
073            public CentralizedThreadLocal(boolean shortLived) {
074                    _shortLived = shortLived;
075    
076                    if (shortLived) {
077                            _hashCode = _shortLivedNextHasCode.getAndAdd(_HASH_INCREMENT);
078                    }
079                    else {
080                            _hashCode = _longLivedNextHasCode.getAndAdd(_HASH_INCREMENT);
081                    }
082            }
083    
084            @Override
085            public T get() {
086                    ThreadLocalMap threadLocalMap = _getThreadLocalMap();
087    
088                    Entry entry = threadLocalMap.getEntry(this);
089    
090                    if (entry == null) {
091                            T value = initialValue();
092    
093                            threadLocalMap.putEntry(this, value);
094    
095                            return value;
096                    }
097                    else {
098                            return (T)entry._value;
099                    }
100            }
101    
102            @Override
103            public int hashCode() {
104                    return _hashCode;
105            }
106    
107            @Override
108            public void remove() {
109                    ThreadLocalMap threadLocalMap = _getThreadLocalMap();
110    
111                    threadLocalMap.removeEntry(this);
112            }
113    
114            @Override
115            public void set(T value) {
116                    ThreadLocalMap threadLocalMap = _getThreadLocalMap();
117    
118                    threadLocalMap.putEntry(this, value);
119            }
120    
121            private static Map<CentralizedThreadLocal<?>, Object> _toMap(
122                    ThreadLocalMap threadLocalMap) {
123    
124                    Map<CentralizedThreadLocal<?>, Object> map =
125                            new HashMap<CentralizedThreadLocal<?>, Object>(
126                                    threadLocalMap._table.length);
127    
128                    for (Entry entry : threadLocalMap._table) {
129                            map.put(entry._key, entry._value);
130                    }
131    
132                    return map;
133            }
134    
135            private ThreadLocalMap _getThreadLocalMap() {
136                    if (_shortLived) {
137                            return _shortLivedThreadLocals.get();
138                    }
139                    else {
140                            return _longLivedThreadLocals.get();
141                    }
142            }
143    
144            private static final int _HASH_INCREMENT = 0x61c88647;
145    
146            private static Log _log = LogFactoryUtil.getLog(
147                    CentralizedThreadLocal.class);
148    
149            private static final AtomicInteger _longLivedNextHasCode =
150                    new AtomicInteger();
151            private static final ThreadLocal<ThreadLocalMap> _longLivedThreadLocals =
152                    new ThreadLocalMapThreadLocal();
153            private static final AtomicInteger _shortLivedNextHasCode =
154                    new AtomicInteger();
155            private static final ThreadLocal<ThreadLocalMap> _shortLivedThreadLocals =
156                    new ThreadLocalMapThreadLocal();
157    
158            private final int _hashCode;
159            private final boolean _shortLived;
160    
161            private static class Entry {
162    
163                    public Entry(CentralizedThreadLocal<?> key, Object value, Entry next) {
164                            _key = key;
165                            _value = value;
166                            _next = next;
167                    }
168    
169                    private CentralizedThreadLocal<?> _key;
170                    private Entry _next;
171                    private Object _value;
172    
173            }
174    
175            private static class ThreadLocalMap {
176    
177                    public void expand(int newCapacity) {
178                            if (_table.length == _MAXIMUM_CAPACITY) {
179                                    _threshold = Integer.MAX_VALUE;
180    
181                                    return;
182                            }
183    
184                            Entry[] newTable = new Entry[newCapacity];
185    
186                            for (int i = 0; i < _table.length; i++) {
187                                    Entry entry = _table[i];
188    
189                                    if (entry == null) {
190                                            continue;
191                                    }
192    
193                                    _table[i] = null;
194    
195                                    do {
196                                            Entry nextEntry = entry._next;
197    
198                                            int index = entry._key._hashCode & (newCapacity - 1);
199    
200                                            entry._next = newTable[index];
201    
202                                            newTable[index] = entry;
203    
204                                            entry = nextEntry;
205                                    }
206                                    while (entry != null);
207                            }
208    
209                            _table = newTable;
210    
211                            _threshold = newCapacity * 2 / 3;
212                    }
213    
214                    public Entry getEntry(CentralizedThreadLocal<?> key) {
215                            int index = key._hashCode & (_table.length - 1);
216    
217                            Entry entry = _table[index];
218    
219                            if (entry == null) {
220                                    return null;
221                            }
222                            else if (entry._key == key) {
223                                    return entry;
224                            }
225                            else {
226                                    while ((entry = entry._next) != null) {
227                                            if (entry._key == key) {
228                                                    return entry;
229                                            }
230                                    }
231    
232                                    return null;
233                            }
234                    }
235    
236                    public void putEntry(CentralizedThreadLocal<?> key, Object value) {
237                            int index = key._hashCode & (_table.length - 1);
238    
239                            for (Entry entry = _table[index]; entry != null;
240                                    entry = entry._next) {
241    
242                                    if (entry._key == key) {
243                                            _closeEntry(entry._value);
244    
245                                            entry._value = value;
246    
247                                            return;
248                                    }
249                            }
250    
251                            _table[index] = new Entry(key, value, _table[index]);
252    
253                            if (_size++ >= _threshold) {
254                                    expand(2 * _table.length);
255                            }
256                    }
257    
258                    public void closeEntries() {
259                            for (Entry entry : _table) {
260                                    if (entry == null) {
261                                            continue;
262                                    }
263    
264                                    _closeEntry(entry._value);
265                            }
266                    }
267    
268                    public void removeEntry(CentralizedThreadLocal<?> key) {
269                            int index = key._hashCode & (_table.length - 1);
270    
271                            Entry previousEntry = null;
272    
273                            Entry entry = _table[index];
274    
275                            while (entry != null) {
276                                    Entry nextEntry = entry._next;
277    
278                                    if (entry._key == key) {
279                                            _size--;
280    
281                                            _closeEntry(entry._value);
282    
283                                            if (previousEntry == null) {
284                                                    _table[index] = nextEntry;
285                                            }
286                                            else {
287                                                    previousEntry._next = nextEntry;
288                                            }
289    
290                                            return;
291                                    }
292    
293                                    previousEntry = entry;
294                                    entry = nextEntry;
295                            }
296                    }
297    
298                    protected void _closeEntry(Object value) {
299                            if (value == null) {
300                                    return;
301                            }
302    
303                            if (value instanceof Closeable) {
304                                    Closeable closable = (Closeable)value;
305    
306                                    try {
307                                            closable.close();
308                                    }
309                                    catch (IOException ioe) {
310                                            _log.error(ioe, ioe);
311                                    }
312                            }
313                    }
314    
315                    private static final int _INITIAL_CAPACITY = 16;
316    
317                    private static final int _MAXIMUM_CAPACITY = 1 << 30;
318    
319                    private int _size;
320                    private Entry[] _table = new Entry[_INITIAL_CAPACITY];
321                    private int _threshold = _INITIAL_CAPACITY * 2 / 3;
322    
323            }
324    
325            private static class ThreadLocalMapThreadLocal
326                    extends ThreadLocal<ThreadLocalMap> {
327    
328                    @Override
329                    protected ThreadLocalMap initialValue() {
330                            return new ThreadLocalMap();
331                    }
332    
333                    @Override
334                    public void remove() {
335                            ThreadLocalMap threadLocalMap = get();
336    
337                            threadLocalMap.closeEntries();
338    
339                            super.remove();
340                    }
341    
342            }
343    
344    }