001
014
015 package com.liferay.portal.dao.shard;
016
017 import com.liferay.counter.service.persistence.CounterFinder;
018 import com.liferay.counter.service.persistence.CounterPersistence;
019 import com.liferay.portal.NoSuchCompanyException;
020 import com.liferay.portal.kernel.exception.PortalException;
021 import com.liferay.portal.kernel.exception.SystemException;
022 import com.liferay.portal.kernel.log.Log;
023 import com.liferay.portal.kernel.log.LogFactoryUtil;
024 import com.liferay.portal.kernel.util.InfrastructureUtil;
025 import com.liferay.portal.kernel.util.InitialThreadLocal;
026 import com.liferay.portal.kernel.util.StringPool;
027 import com.liferay.portal.kernel.util.StringUtil;
028 import com.liferay.portal.model.Company;
029 import com.liferay.portal.model.Shard;
030 import com.liferay.portal.security.auth.CompanyThreadLocal;
031 import com.liferay.portal.service.CompanyLocalServiceUtil;
032 import com.liferay.portal.service.ShardLocalServiceUtil;
033 import com.liferay.portal.service.persistence.ClassNamePersistence;
034 import com.liferay.portal.service.persistence.CompanyPersistence;
035 import com.liferay.portal.service.persistence.ReleasePersistence;
036 import com.liferay.portal.service.persistence.ShardPersistence;
037 import com.liferay.portal.util.PropsValues;
038
039 import java.util.EmptyStackException;
040 import java.util.HashMap;
041 import java.util.Map;
042 import java.util.Stack;
043
044 import javax.sql.DataSource;
045
046 import org.aspectj.lang.ProceedingJoinPoint;
047
048
052 public class ShardAdvice {
053
054 public void afterPropertiesSet() {
055 if (_shardDataSourceTargetSource == null) {
056 _shardDataSourceTargetSource =
057 (ShardDataSourceTargetSource)InfrastructureUtil.
058 getShardDataSourceTargetSource();
059 }
060
061 if (_shardSessionFactoryTargetSource == null) {
062 _shardSessionFactoryTargetSource =
063 (ShardSessionFactoryTargetSource)InfrastructureUtil.
064 getShardSessionFactoryTargetSource();
065 }
066 }
067
068 public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
069 throws Throwable {
070
071 Object[] arguments = proceedingJoinPoint.getArgs();
072
073 long companyId = (Long)arguments[0];
074
075 Shard shard = ShardLocalServiceUtil.getShard(
076 Company.class.getName(), companyId);
077
078 String shardName = shard.getName();
079
080 if (_log.isInfoEnabled()) {
081 _log.info(
082 "Service being set to shard " + shardName + " for " +
083 _getSignature(proceedingJoinPoint));
084 }
085
086 Object returnValue = null;
087
088 pushCompanyService(shardName);
089
090 try {
091 returnValue = proceedingJoinPoint.proceed();
092 }
093 finally {
094 popCompanyService();
095 }
096
097 return returnValue;
098 }
099
100 public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
101 throws Throwable {
102
103 String methodName = proceedingJoinPoint.getSignature().getName();
104 Object[] arguments = proceedingJoinPoint.getArgs();
105
106 String shardName = PropsValues.SHARD_DEFAULT_NAME;
107
108 if (methodName.equals("addCompany")) {
109 String webId = (String)arguments[0];
110 String virtualHost = (String)arguments[1];
111 String mx = (String)arguments[2];
112 shardName = (String)arguments[3];
113
114 shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
115
116 arguments[3] = shardName;
117 }
118 else if (methodName.equals("checkCompany")) {
119 String webId = (String)arguments[0];
120
121 if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
122 if (arguments.length == 3) {
123 String mx = (String)arguments[1];
124 shardName = (String)arguments[2];
125
126 shardName = _getCompanyShardName(
127 webId, null, mx, shardName);
128
129 arguments[2] = shardName;
130 }
131
132 try {
133 Company company = CompanyLocalServiceUtil.getCompanyByWebId(
134 webId);
135
136 shardName = company.getShardName();
137 }
138 catch (NoSuchCompanyException nsce) {
139 }
140 }
141 }
142 else if (methodName.startsWith("update")) {
143 long companyId = (Long)arguments[0];
144
145 Shard shard = ShardLocalServiceUtil.getShard(
146 Company.class.getName(), companyId);
147
148 shardName = shard.getName();
149 }
150 else {
151 return proceedingJoinPoint.proceed();
152 }
153
154 if (_log.isInfoEnabled()) {
155 _log.info(
156 "Company service being set to shard " + shardName + " for " +
157 _getSignature(proceedingJoinPoint));
158 }
159
160 Object returnValue = null;
161
162 pushCompanyService(shardName);
163
164 try {
165 returnValue = proceedingJoinPoint.proceed(arguments);
166 }
167 finally {
168 popCompanyService();
169 }
170
171 return returnValue;
172 }
173
174
180 public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
181 throws Throwable {
182
183 _globalCall.set(new Object());
184
185 try {
186 if (_log.isInfoEnabled()) {
187 _log.info(
188 "All shards invoked for " +
189 _getSignature(proceedingJoinPoint));
190 }
191
192 for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
193 _shardDataSourceTargetSource.setDataSource(shardName);
194 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
195
196 proceedingJoinPoint.proceed();
197 }
198 }
199 finally {
200 _globalCall.set(null);
201 }
202
203 return null;
204 }
205
206
212 public Object invokeIteratively(ProceedingJoinPoint proceedingJoinPoint)
213 throws Throwable {
214
215 if (_log.isInfoEnabled()) {
216 _log.info(
217 "Iterating through all shards for " +
218 _getSignature(proceedingJoinPoint));
219 }
220
221 for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
222 pushCompanyService(shardName);
223
224 try {
225 proceedingJoinPoint.proceed();
226 }
227 finally {
228 popCompanyService();
229 }
230 }
231
232 return null;
233 }
234
235 public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
236 throws Throwable {
237
238 if ((_shardDataSourceTargetSource == null) ||
239 (_shardSessionFactoryTargetSource == null)) {
240
241 return proceedingJoinPoint.proceed();
242 }
243
244 Object target = proceedingJoinPoint.getTarget();
245
246 if (target instanceof ClassNamePersistence ||
247 target instanceof CompanyPersistence ||
248 target instanceof CounterFinder ||
249 target instanceof CounterPersistence ||
250 target instanceof ReleasePersistence ||
251 target instanceof ShardPersistence) {
252
253 _shardDataSourceTargetSource.setDataSource(
254 PropsValues.SHARD_DEFAULT_NAME);
255 _shardSessionFactoryTargetSource.setSessionFactory(
256 PropsValues.SHARD_DEFAULT_NAME);
257
258 if (_log.isDebugEnabled()) {
259 _log.debug(
260 "Using default shard for " +
261 _getSignature(proceedingJoinPoint));
262 }
263
264 return proceedingJoinPoint.proceed();
265 }
266
267 if (_globalCall.get() == null) {
268 _setShardNameByCompany();
269
270 String shardName = _getShardName();
271
272 _shardDataSourceTargetSource.setDataSource(shardName);
273 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
274
275 if (_log.isInfoEnabled()) {
276 _log.info(
277 "Using shard name " + shardName + " for " +
278 _getSignature(proceedingJoinPoint));
279 }
280
281 return proceedingJoinPoint.proceed();
282 }
283 else {
284 return proceedingJoinPoint.proceed();
285 }
286 }
287
288 public void setShardDataSourceTargetSource(
289 ShardDataSourceTargetSource shardDataSourceTargetSource) {
290
291 _shardDataSourceTargetSource = shardDataSourceTargetSource;
292 }
293
294 public void setShardSessionFactoryTargetSource(
295 ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
296
297 _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
298 }
299
300 protected String getCurrentShardName() {
301 String shardName = null;
302
303 try {
304 shardName = _getCompanyServiceStack().peek();
305 }
306 catch (EmptyStackException ese) {
307 }
308
309 if (shardName == null) {
310 shardName = PropsValues.SHARD_DEFAULT_NAME;
311 }
312
313 return shardName;
314 }
315
316 protected DataSource getDataSource() {
317 return _shardDataSourceTargetSource.getDataSource();
318 }
319
320 protected String popCompanyService() {
321 return _getCompanyServiceStack().pop();
322 }
323
324 protected void pushCompanyService(long companyId) {
325 try {
326 Shard shard = ShardLocalServiceUtil.getShard(
327 Company.class.getName(), companyId);
328
329 String shardName = shard.getName();
330
331 pushCompanyService(shardName);
332 }
333 catch (Exception e) {
334 _log.error(e, e);
335 }
336 }
337
338 protected void pushCompanyService(String shardName) {
339 _getCompanyServiceStack().push(shardName);
340 }
341
342 private Stack<String> _getCompanyServiceStack() {
343 Stack<String> companyServiceStack = _companyServiceStack.get();
344
345 if (companyServiceStack == null) {
346 companyServiceStack = new Stack<String>();
347
348 _companyServiceStack.set(companyServiceStack);
349 }
350
351 return companyServiceStack;
352 }
353
354 private String _getCompanyShardName(
355 String webId, String virtualHost, String mx, String shardName) {
356
357 Map<String, String> shardParams = new HashMap<String, String>();
358
359 shardParams.put("webId", webId);
360 shardParams.put("mx", mx);
361
362 if (virtualHost != null) {
363 shardParams.put("virtualHost", virtualHost);
364 }
365
366 shardName = _shardSelector.getShardName(
367 ShardSelector.COMPANY_SCOPE, shardName, shardParams);
368
369 return shardName;
370 }
371
372 private String _getShardName() {
373 return _shardName.get();
374 }
375
376 private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
377 String methodName = StringUtil.extractLast(
378 proceedingJoinPoint.getTarget().getClass().getName(),
379 StringPool.PERIOD);
380
381 methodName +=
382 StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
383 "()";
384
385 return methodName;
386 }
387
388 private void _setShardName(String shardName) {
389 _shardName.set(shardName);
390 }
391
392 private void _setShardNameByCompany() throws Throwable {
393 Stack<String> companyServiceStack = _getCompanyServiceStack();
394
395 if (companyServiceStack.isEmpty()) {
396 long companyId = CompanyThreadLocal.getCompanyId();
397
398 _setShardNameByCompanyId(companyId);
399 }
400 else {
401 String shardName = companyServiceStack.peek();
402
403 _setShardName(shardName);
404 }
405 }
406
407 private void _setShardNameByCompanyId(long companyId)
408 throws PortalException, SystemException {
409
410 if (companyId == 0) {
411 _setShardName(PropsValues.SHARD_DEFAULT_NAME);
412 }
413 else {
414 Shard shard = ShardLocalServiceUtil.getShard(
415 Company.class.getName(), companyId);
416
417 String shardName = shard.getName();
418
419 _setShardName(shardName);
420 }
421 }
422
423 private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
424
425 private static ThreadLocal<Stack<String>> _companyServiceStack =
426 new ThreadLocal<Stack<String>>();
427 private static ThreadLocal<Object> _globalCall = new ThreadLocal<Object>();
428 private static ThreadLocal<String> _shardName =
429 new InitialThreadLocal<String>(
430 ShardAdvice.class + "._shardName", PropsValues.SHARD_DEFAULT_NAME);
431 private static ShardSelector _shardSelector;
432
433 private ShardDataSourceTargetSource _shardDataSourceTargetSource;
434 private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
435
436 static {
437 try {
438 _shardSelector = (ShardSelector)Class.forName(
439 PropsValues.SHARD_SELECTOR).newInstance();
440 }
441 catch (Exception e) {
442 _log.error(e, e);
443 }
444 }
445
446 }