001    /**
002     * Copyright (c) 2000-2012 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    
020    import java.io.IOException;
021    
022    import java.lang.ref.WeakReference;
023    import java.lang.reflect.InvocationTargetException;
024    import java.lang.reflect.Method;
025    
026    import java.net.URL;
027    
028    import java.util.ArrayList;
029    import java.util.Collection;
030    import java.util.Collections;
031    import java.util.Enumeration;
032    import java.util.Iterator;
033    import java.util.List;
034    
035    /**
036     * @author Brian Wing Shun Chan
037     * @author Michael C. Han
038     * @author Shuyang Zhou
039     */
040    public class AggregateClassLoader extends ClassLoader {
041    
042            public static ClassLoader getAggregateClassLoader(
043                    ClassLoader parentClassLoader, ClassLoader[] classLoaders) {
044    
045                    if ((classLoaders == null) || (classLoaders.length == 0)) {
046                            return null;
047                    }
048    
049                    if (classLoaders.length == 1) {
050                            return classLoaders[0];
051                    }
052    
053                    AggregateClassLoader aggregateClassLoader = new AggregateClassLoader(
054                            parentClassLoader);
055    
056                    for (ClassLoader classLoader : classLoaders) {
057                            aggregateClassLoader.addClassLoader(classLoader);
058                    }
059    
060                    return aggregateClassLoader;
061            }
062    
063            public static ClassLoader getAggregateClassLoader(
064                    ClassLoader[] classLoaders) {
065    
066                    if ((classLoaders == null) || (classLoaders.length == 0)) {
067                            return null;
068                    }
069    
070                    return getAggregateClassLoader(classLoaders[0], classLoaders);
071            }
072    
073            public AggregateClassLoader(ClassLoader classLoader) {
074                    _parentClassLoaderReference = new WeakReference<ClassLoader>(
075                            classLoader);
076            }
077    
078            public void addClassLoader(ClassLoader classLoader) {
079                    if (getClassLoaders().contains(classLoader)) {
080                            return;
081                    }
082    
083                    if ((classLoader instanceof AggregateClassLoader) &&
084                            (classLoader.getParent().equals(getParent()))) {
085    
086                            AggregateClassLoader aggregateClassLoader =
087                                    (AggregateClassLoader)classLoader;
088    
089                            for (ClassLoader curClassLoader :
090                                            aggregateClassLoader.getClassLoaders()) {
091    
092                                    addClassLoader(curClassLoader);
093                            }
094                    }
095                    else {
096                            _classLoaderReferences.add(
097                                    new WeakReference<ClassLoader>(classLoader));
098                    }
099            }
100    
101            public void addClassLoader(ClassLoader... classLoaders) {
102                    for (ClassLoader classLoader : classLoaders) {
103                            addClassLoader(classLoader);
104                    }
105            }
106    
107            public void addClassLoader(Collection<ClassLoader> classLoaders) {
108                    for (ClassLoader classLoader : classLoaders) {
109                            addClassLoader(classLoader);
110                    }
111            }
112    
113            @Override
114            public boolean equals(Object obj) {
115                    if (this == obj) {
116                            return true;
117                    }
118    
119                    if (!(obj instanceof AggregateClassLoader)) {
120                            return false;
121                    }
122    
123                    AggregateClassLoader aggregateClassLoader = (AggregateClassLoader)obj;
124    
125                    if (_classLoaderReferences.equals(
126                                    aggregateClassLoader._classLoaderReferences) &&
127                            (((getParent() == null) &&
128                              (aggregateClassLoader.getParent() == null)) ||
129                             ((getParent() != null) &&
130                              (getParent().equals(aggregateClassLoader.getParent()))))) {
131    
132                            return true;
133                    }
134    
135                    return false;
136            }
137    
138            public List<ClassLoader> getClassLoaders() {
139                    List<ClassLoader> classLoaders = new ArrayList<ClassLoader>(
140                            _classLoaderReferences.size());
141    
142                    Iterator<WeakReference<ClassLoader>> itr =
143                            _classLoaderReferences.iterator();
144    
145                    while (itr.hasNext()) {
146                            WeakReference<ClassLoader> weakReference = itr.next();
147    
148                            ClassLoader classLoader = weakReference.get();
149    
150                            if (classLoader == null) {
151                                    itr.remove();
152                            }
153                            else {
154                                    classLoaders.add(classLoader);
155                            }
156                    }
157    
158                    return classLoaders;
159            }
160    
161            @Override
162            public URL getResource(String name) {
163                    for (ClassLoader classLoader : getClassLoaders()) {
164                            URL url = _getResource(classLoader, name);
165    
166                            if (url != null) {
167                                    return url;
168                            }
169                    }
170    
171                    ClassLoader parentClassLoader = _parentClassLoaderReference.get();
172    
173                    if (parentClassLoader == null) {
174                            return null;
175                    }
176    
177                    return parentClassLoader.getResource(name);
178            }
179    
180            @Override
181            public Enumeration<URL> getResources(String name)
182                    throws IOException {
183    
184                    List<URL> urls = new ArrayList<URL>();
185    
186                    for (ClassLoader classLoader : getClassLoaders()) {
187                            urls.addAll(Collections.list(_getResources(classLoader, name)));
188                    }
189    
190                    ClassLoader parentClassLoader = _parentClassLoaderReference.get();
191    
192                    if (parentClassLoader != null) {
193                            urls.addAll(
194                                    Collections.list(_getResources(parentClassLoader, name)));
195                    }
196    
197                    return Collections.enumeration(urls);
198            }
199    
200            @Override
201            public int hashCode() {
202                    if (_classLoaderReferences != null) {
203                            return _classLoaderReferences.hashCode();
204                    }
205                    else {
206                            return 0;
207                    }
208            }
209    
210            @Override
211            protected Class<?> findClass(String name) throws ClassNotFoundException {
212                    for (ClassLoader classLoader : getClassLoaders()) {
213                            try {
214                                    return _findClass(classLoader, name);
215                            }
216                            catch (ClassNotFoundException cnfe) {
217                            }
218                    }
219    
220                    throw new ClassNotFoundException("Unable to find class " + name);
221            }
222    
223            @Override
224            protected synchronized Class<?> loadClass(String name, boolean resolve)
225                    throws ClassNotFoundException {
226    
227                    Class<?> loadedClass = null;
228    
229                    for (ClassLoader classLoader : getClassLoaders()) {
230                            try {
231                                    loadedClass = _loadClass(classLoader, name, resolve);
232    
233                                    break;
234                            }
235                            catch (ClassNotFoundException cnfe) {
236                            }
237                    }
238    
239                    if (loadedClass == null) {
240                            ClassLoader parentClassLoader = _parentClassLoaderReference.get();
241    
242                            if (parentClassLoader == null) {
243                                    throw new ClassNotFoundException(
244                                            "Parent class loader has been garbage collected");
245                            }
246    
247                            loadedClass = _loadClass(parentClassLoader, name, resolve);
248                    }
249                    else if (resolve) {
250                            resolveClass(loadedClass);
251                    }
252    
253                    return loadedClass;
254            }
255    
256            private static Class<?> _findClass(ClassLoader classLoader, String name)
257                    throws ClassNotFoundException {
258    
259                    try {
260                            return (Class<?>) _findClassMethod.invoke(classLoader, name);
261                    }
262                    catch (InvocationTargetException ite) {
263                            throw new ClassNotFoundException(
264                                    "Unable to find class " + name, ite.getTargetException());
265                    }
266                    catch (Exception e) {
267                            throw new ClassNotFoundException("Unable to find class " + name, e);
268                    }
269            }
270    
271            private static URL _getResource(ClassLoader classLoader, String name) {
272                    try {
273                            return (URL)_getResourceMethod.invoke(classLoader, name);
274                    }
275                    catch (InvocationTargetException ite) {
276                            return null;
277                    }
278                    catch (Exception e) {
279                            return null;
280                    }
281            }
282    
283            private static Enumeration<URL> _getResources(
284                            ClassLoader classLoader, String name)
285                    throws IOException {
286    
287                    try {
288                            return (Enumeration<URL>)_getResourcesMethod.invoke(
289                                    classLoader, name);
290                    }
291                    catch (InvocationTargetException ite) {
292                            Throwable t = ite.getTargetException();
293    
294                            throw new IOException(t.getMessage());
295                    }
296                    catch (Exception e) {
297                            throw new IOException(e.getMessage());
298                    }
299            }
300    
301            private static Class<?> _loadClass(
302                            ClassLoader classLoader, String name, boolean resolve)
303                    throws ClassNotFoundException {
304    
305                    try {
306                            return (Class<?>) _loadClassMethod.invoke(
307                                    classLoader, name, resolve);
308                    }
309                    catch (InvocationTargetException ite) {
310                            throw new ClassNotFoundException(
311                                    "Unable to load class " + name, ite.getTargetException());
312                    }
313                    catch (Exception e) {
314                            throw new ClassNotFoundException("Unable to load class " + name, e);
315                    }
316            }
317    
318            private static Log _log = LogFactoryUtil.getLog(AggregateClassLoader.class);
319    
320            private static Method _findClassMethod;
321            private static Method _getResourceMethod;
322            private static Method _getResourcesMethod;
323            private static Method _loadClassMethod;
324    
325            private List<WeakReference<ClassLoader>> _classLoaderReferences =
326                    new ArrayList<WeakReference<ClassLoader>>();
327            private WeakReference<ClassLoader> _parentClassLoaderReference;
328    
329            static {
330                    try {
331                            _findClassMethod = ReflectionUtil.getDeclaredMethod(
332                                    ClassLoader.class, "findClass", String.class);
333                            _getResourceMethod = ReflectionUtil.getDeclaredMethod(
334                                    ClassLoader.class, "getResource", String.class);
335                            _getResourcesMethod = ReflectionUtil.getDeclaredMethod(
336                                    ClassLoader.class, "getResources", String.class);
337                            _loadClassMethod = ReflectionUtil.getDeclaredMethod(
338                                    ClassLoader.class, "loadClass", String.class, boolean.class);
339                    }
340                    catch (Exception e) {
341                            if (_log.isErrorEnabled()) {
342                                    _log.error("Unable to locate required methods", e);
343                            }
344                    }
345            }
346    
347    }