001/*- 002 * #%L 003 * HAPI FHIR - Core Library 004 * %% 005 * Copyright (C) 2014 - 2023 Smile CDR, Inc. 006 * %% 007 * Licensed under the Apache License, Version 2.0 (the "License"); 008 * you may not use this file except in compliance with the License. 009 * You may obtain a copy of the License at 010 * 011 * http://www.apache.org/licenses/LICENSE-2.0 012 * 013 * Unless required by applicable law or agreed to in writing, software 014 * distributed under the License is distributed on an "AS IS" BASIS, 015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 016 * See the License for the specific language governing permissions and 017 * limitations under the License. 018 * #L% 019 */ 020package ca.uhn.fhir.interceptor.executor; 021 022import ca.uhn.fhir.i18n.Msg; 023import ca.uhn.fhir.interceptor.api.HookParams; 024import ca.uhn.fhir.interceptor.api.IBaseInterceptorBroadcaster; 025import ca.uhn.fhir.interceptor.api.IBaseInterceptorService; 026import ca.uhn.fhir.interceptor.api.IPointcut; 027import ca.uhn.fhir.interceptor.api.Interceptor; 028import ca.uhn.fhir.interceptor.api.Pointcut; 029import ca.uhn.fhir.rest.server.exceptions.InternalErrorException; 030import ca.uhn.fhir.util.ReflectionUtil; 031import com.google.common.annotations.VisibleForTesting; 032import com.google.common.collect.ArrayListMultimap; 033import com.google.common.collect.ListMultimap; 034import org.apache.commons.lang3.Validate; 035import org.apache.commons.lang3.builder.ToStringBuilder; 036import org.apache.commons.lang3.builder.ToStringStyle; 037import org.apache.commons.lang3.reflect.MethodUtils; 038import org.slf4j.Logger; 039import org.slf4j.LoggerFactory; 040 041import javax.annotation.Nonnull; 042import javax.annotation.Nullable; 043import java.lang.annotation.Annotation; 044import java.lang.reflect.AnnotatedElement; 045import java.lang.reflect.InvocationTargetException; 046import java.lang.reflect.Method; 047import java.util.ArrayList; 048import java.util.Arrays; 049import java.util.Collection; 050import java.util.Collections; 051import java.util.Comparator; 052import java.util.EnumSet; 053import java.util.HashMap; 054import java.util.IdentityHashMap; 055import java.util.List; 056import java.util.Map; 057import java.util.Objects; 058import java.util.Optional; 059import java.util.concurrent.atomic.AtomicInteger; 060import java.util.function.Predicate; 061import java.util.stream.Collectors; 062 063public abstract class BaseInterceptorService<POINTCUT extends Enum<POINTCUT> & IPointcut> implements IBaseInterceptorService<POINTCUT>, IBaseInterceptorBroadcaster<POINTCUT> { 064 private static final Logger ourLog = LoggerFactory.getLogger(BaseInterceptorService.class); 065 private final List<Object> myInterceptors = new ArrayList<>(); 066 private final ListMultimap<POINTCUT, BaseInvoker> myGlobalInvokers = ArrayListMultimap.create(); 067 private final ListMultimap<POINTCUT, BaseInvoker> myAnonymousInvokers = ArrayListMultimap.create(); 068 private final Object myRegistryMutex = new Object(); 069 private final Class<POINTCUT> myPointcutType; 070 private volatile EnumSet<POINTCUT> myRegisteredPointcuts; 071 private String myName; 072 private boolean myWarnOnInterceptorWithNoHooks = true; 073 074 /** 075 * Constructor which uses a default name of "default" 076 */ 077 public BaseInterceptorService(Class<POINTCUT> thePointcutType) { 078 this(thePointcutType, "default"); 079 } 080 081 /** 082 * Constructor 083 * 084 * @param theName The name for this registry (useful for troubleshooting) 085 */ 086 public BaseInterceptorService(Class<POINTCUT> thePointcutType, String theName) { 087 super(); 088 myName = theName; 089 myPointcutType = thePointcutType; 090 rebuildRegisteredPointcutSet(); 091 } 092 093 /** 094 * Should a warning be issued if an interceptor is registered and it has no hooks 095 */ 096 public void setWarnOnInterceptorWithNoHooks(boolean theWarnOnInterceptorWithNoHooks) { 097 myWarnOnInterceptorWithNoHooks = theWarnOnInterceptorWithNoHooks; 098 } 099 100 @VisibleForTesting 101 List<Object> getGlobalInterceptorsForUnitTest() { 102 return myInterceptors; 103 } 104 105 public void setName(String theName) { 106 myName = theName; 107 } 108 109 protected void registerAnonymousInterceptor(POINTCUT thePointcut, Object theInterceptor, BaseInvoker theInvoker) { 110 Validate.notNull(thePointcut); 111 Validate.notNull(theInterceptor); 112 synchronized (myRegistryMutex) { 113 114 myAnonymousInvokers.put(thePointcut, theInvoker); 115 if (!isInterceptorAlreadyRegistered(theInterceptor)) { 116 myInterceptors.add(theInterceptor); 117 } 118 119 rebuildRegisteredPointcutSet(); 120 } 121 } 122 123 @Override 124 public List<Object> getAllRegisteredInterceptors() { 125 synchronized (myRegistryMutex) { 126 List<Object> retVal = new ArrayList<>(myInterceptors); 127 return Collections.unmodifiableList(retVal); 128 } 129 } 130 131 @Override 132 @VisibleForTesting 133 public void unregisterAllInterceptors() { 134 synchronized (myRegistryMutex) { 135 unregisterInterceptors(myAnonymousInvokers.values()); 136 unregisterInterceptors(myGlobalInvokers.values()); 137 unregisterInterceptors(myInterceptors); 138 } 139 } 140 141 @Override 142 public void unregisterInterceptors(@Nullable Collection<?> theInterceptors) { 143 if (theInterceptors != null) { 144 // We construct a new list before iterating because the service's internal 145 // interceptor lists get passed into this method, and we get concurrent 146 // modification errors if we modify them at the same time as we iterate them 147 new ArrayList<>(theInterceptors).forEach(this::unregisterInterceptor); 148 } 149 } 150 151 @Override 152 public void registerInterceptors(@Nullable Collection<?> theInterceptors) { 153 if (theInterceptors != null) { 154 theInterceptors.forEach(this::registerInterceptor); 155 } 156 } 157 158 @Override 159 public void unregisterAllAnonymousInterceptors() { 160 synchronized (myRegistryMutex) { 161 unregisterInterceptorsIf(t -> true, myAnonymousInvokers); 162 } 163 } 164 165 @Override 166 public void unregisterInterceptorsIf(Predicate<Object> theShouldUnregisterFunction) { 167 unregisterInterceptorsIf(theShouldUnregisterFunction, myGlobalInvokers); 168 unregisterInterceptorsIf(theShouldUnregisterFunction, myAnonymousInvokers); 169 } 170 171 private void unregisterInterceptorsIf(Predicate<Object> theShouldUnregisterFunction, ListMultimap<POINTCUT, BaseInvoker> theGlobalInvokers) { 172 synchronized (myRegistryMutex) { 173 for (Map.Entry<POINTCUT, BaseInvoker> nextInvoker : new ArrayList<>(theGlobalInvokers.entries())) { 174 if (theShouldUnregisterFunction.test(nextInvoker.getValue().getInterceptor())) { 175 unregisterInterceptor(nextInvoker.getValue().getInterceptor()); 176 } 177 } 178 179 rebuildRegisteredPointcutSet(); 180 } 181 } 182 183 @Override 184 public boolean registerInterceptor(Object theInterceptor) { 185 synchronized (myRegistryMutex) { 186 187 if (isInterceptorAlreadyRegistered(theInterceptor)) { 188 return false; 189 } 190 191 List<HookInvoker> addedInvokers = scanInterceptorAndAddToInvokerMultimap(theInterceptor, myGlobalInvokers); 192 if (addedInvokers.isEmpty()) { 193 if (myWarnOnInterceptorWithNoHooks) { 194 ourLog.warn("Interceptor registered with no valid hooks - Type was: {}", theInterceptor.getClass().getName()); 195 } 196 return false; 197 } 198 199 // Add to the global list 200 myInterceptors.add(theInterceptor); 201 sortByOrderAnnotation(myInterceptors); 202 203 rebuildRegisteredPointcutSet(); 204 205 return true; 206 } 207 } 208 209 private void rebuildRegisteredPointcutSet() { 210 EnumSet<POINTCUT> registeredPointcuts = EnumSet.noneOf(myPointcutType); 211 registeredPointcuts.addAll(myAnonymousInvokers.keySet()); 212 registeredPointcuts.addAll(myGlobalInvokers.keySet()); 213 myRegisteredPointcuts = registeredPointcuts; 214 } 215 216 private boolean isInterceptorAlreadyRegistered(Object theInterceptor) { 217 for (Object next : myInterceptors) { 218 if (next == theInterceptor) { 219 return true; 220 } 221 } 222 return false; 223 } 224 225 @Override 226 public boolean unregisterInterceptor(Object theInterceptor) { 227 synchronized (myRegistryMutex) { 228 boolean removed = myInterceptors.removeIf(t -> t == theInterceptor); 229 removed |= myGlobalInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor); 230 removed |= myAnonymousInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor); 231 rebuildRegisteredPointcutSet(); 232 return removed; 233 } 234 } 235 236 private void sortByOrderAnnotation(List<Object> theObjects) { 237 IdentityHashMap<Object, Integer> interceptorToOrder = new IdentityHashMap<>(); 238 for (Object next : theObjects) { 239 Interceptor orderAnnotation = next.getClass().getAnnotation(Interceptor.class); 240 int order = orderAnnotation != null ? orderAnnotation.order() : 0; 241 interceptorToOrder.put(next, order); 242 } 243 244 theObjects.sort((a, b) -> { 245 Integer orderA = interceptorToOrder.get(a); 246 Integer orderB = interceptorToOrder.get(b); 247 return orderA - orderB; 248 }); 249 } 250 251 @Override 252 public Object callHooksAndReturnObject(POINTCUT thePointcut, HookParams theParams) { 253 assert haveAppropriateParams(thePointcut, theParams); 254 assert thePointcut.getReturnType() != void.class; 255 256 return doCallHooks(thePointcut, theParams, null); 257 } 258 259 @Override 260 public boolean hasHooks(POINTCUT thePointcut) { 261 return myRegisteredPointcuts.contains(thePointcut); 262 } 263 264 @Override 265 public boolean callHooks(POINTCUT thePointcut, HookParams theParams) { 266 assert haveAppropriateParams(thePointcut, theParams); 267 assert thePointcut.getReturnType() == void.class || thePointcut.getReturnType() == boolean.class; 268 269 Object retValObj = doCallHooks(thePointcut, theParams, true); 270 return (Boolean) retValObj; 271 } 272 273 private Object doCallHooks(POINTCUT thePointcut, HookParams theParams, Object theRetVal) { 274 // use new list for loop to avoid ConcurrentModificationException in case invoker gets added while looping 275 List<BaseInvoker> invokers = new ArrayList<>(getInvokersForPointcut(thePointcut)); 276 277 /* 278 * Call each hook in order 279 */ 280 for (BaseInvoker nextInvoker : invokers) { 281 Object nextOutcome = nextInvoker.invoke(theParams); 282 Class<?> pointcutReturnType = thePointcut.getReturnType(); 283 if (pointcutReturnType.equals(boolean.class)) { 284 Boolean nextOutcomeAsBoolean = (Boolean) nextOutcome; 285 if (Boolean.FALSE.equals(nextOutcomeAsBoolean)) { 286 ourLog.trace("callHooks({}) for invoker({}) returned false", thePointcut, nextInvoker); 287 theRetVal = false; 288 break; 289 } 290 } else if (pointcutReturnType.equals(void.class) == false) { 291 if (nextOutcome != null) { 292 theRetVal = nextOutcome; 293 break; 294 } 295 } 296 } 297 298 return theRetVal; 299 } 300 301 @VisibleForTesting 302 List<Object> getInterceptorsWithInvokersForPointcut(POINTCUT thePointcut) { 303 return getInvokersForPointcut(thePointcut) 304 .stream() 305 .map(BaseInvoker::getInterceptor) 306 .collect(Collectors.toList()); 307 } 308 309 /** 310 * Returns an ordered list of invokers for the given pointcut. Note that 311 * a new and stable list is returned to.. do whatever you want with it. 312 */ 313 private List<BaseInvoker> getInvokersForPointcut(POINTCUT thePointcut) { 314 List<BaseInvoker> invokers; 315 316 synchronized (myRegistryMutex) { 317 List<BaseInvoker> globalInvokers = myGlobalInvokers.get(thePointcut); 318 List<BaseInvoker> anonymousInvokers = myAnonymousInvokers.get(thePointcut); 319 List<BaseInvoker> threadLocalInvokers = null; 320 invokers = union(globalInvokers, anonymousInvokers, threadLocalInvokers); 321 } 322 323 return invokers; 324 } 325 326 /** 327 * First argument must be the global invoker list!! 328 */ 329 @SafeVarargs 330 private List<BaseInvoker> union(List<BaseInvoker>... theInvokersLists) { 331 List<BaseInvoker> haveOne = null; 332 boolean haveMultiple = false; 333 for (List<BaseInvoker> nextInvokerList : theInvokersLists) { 334 if (nextInvokerList == null || nextInvokerList.isEmpty()) { 335 continue; 336 } 337 338 if (haveOne == null) { 339 haveOne = nextInvokerList; 340 } else { 341 haveMultiple = true; 342 } 343 } 344 345 if (haveOne == null) { 346 return Collections.emptyList(); 347 } 348 349 List<BaseInvoker> retVal; 350 351 if (haveMultiple == false) { 352 353 // The global list doesn't need to be sorted every time since it's sorted on 354 // insertion each time. Doing so is a waste of cycles.. 355 if (haveOne == theInvokersLists[0]) { 356 retVal = haveOne; 357 } else { 358 retVal = new ArrayList<>(haveOne); 359 retVal.sort(Comparator.naturalOrder()); 360 } 361 362 } else { 363 364 retVal = Arrays 365 .stream(theInvokersLists) 366 .filter(Objects::nonNull) 367 .flatMap(Collection::stream) 368 .sorted() 369 .collect(Collectors.toList()); 370 371 } 372 373 return retVal; 374 } 375 376 /** 377 * Only call this when assertions are enabled, it's expensive 378 */ 379 final boolean haveAppropriateParams(POINTCUT thePointcut, HookParams theParams) { 380 if (theParams.getParamsForType().values().size() != thePointcut.getParameterTypes().size()) { 381 throw new IllegalArgumentException(Msg.code(1909) + String.format("Wrong number of params for pointcut %s - Wanted %s but found %s", thePointcut.name(), toErrorString(thePointcut.getParameterTypes()), theParams.getParamsForType().values().stream().map(t -> t != null ? t.getClass().getSimpleName() : "null").sorted().collect(Collectors.toList()))); 382 } 383 384 List<String> wantedTypes = new ArrayList<>(thePointcut.getParameterTypes()); 385 386 ListMultimap<Class<?>, Object> givenTypes = theParams.getParamsForType(); 387 for (Class<?> nextTypeClass : givenTypes.keySet()) { 388 String nextTypeName = nextTypeClass.getName(); 389 for (Object nextParamValue : givenTypes.get(nextTypeClass)) { 390 Validate.isTrue(nextParamValue == null || nextTypeClass.isAssignableFrom(nextParamValue.getClass()), "Invalid params for pointcut %s - %s is not of type %s", thePointcut.name(), nextParamValue != null ? nextParamValue.getClass() : "null", nextTypeClass); 391 Validate.isTrue(wantedTypes.remove(nextTypeName), "Invalid params for pointcut %s - Wanted %s but found %s", thePointcut.name(), toErrorString(thePointcut.getParameterTypes()), nextTypeName); 392 } 393 } 394 395 return true; 396 } 397 398 private List<HookInvoker> scanInterceptorAndAddToInvokerMultimap(Object theInterceptor, ListMultimap<POINTCUT, BaseInvoker> theInvokers) { 399 Class<?> interceptorClass = theInterceptor.getClass(); 400 int typeOrder = determineOrder(interceptorClass); 401 402 List<HookInvoker> addedInvokers = scanInterceptorForHookMethods(theInterceptor, typeOrder); 403 404 // Invoke the REGISTERED pointcut for any added hooks 405 addedInvokers.stream() 406 .filter(t -> Pointcut.INTERCEPTOR_REGISTERED.equals(t.getPointcut())) 407 .forEach(t -> t.invoke(new HookParams())); 408 409 // Register the interceptor and its various hooks 410 for (HookInvoker nextAddedHook : addedInvokers) { 411 POINTCUT nextPointcut = nextAddedHook.getPointcut(); 412 if (nextPointcut.equals(Pointcut.INTERCEPTOR_REGISTERED)) { 413 continue; 414 } 415 theInvokers.put(nextPointcut, nextAddedHook); 416 } 417 418 // Make sure we're always sorted according to the order declared in @Order 419 for (POINTCUT nextPointcut : theInvokers.keys()) { 420 List<BaseInvoker> nextInvokerList = theInvokers.get(nextPointcut); 421 nextInvokerList.sort(Comparator.naturalOrder()); 422 } 423 424 return addedInvokers; 425 } 426 427 /** 428 * @return Returns a list of any added invokers 429 */ 430 private List<HookInvoker> scanInterceptorForHookMethods(Object theInterceptor, int theTypeOrder) { 431 ArrayList<HookInvoker> retVal = new ArrayList<>(); 432 for (Method nextMethod : ReflectionUtil.getDeclaredMethods(theInterceptor.getClass(), true)) { 433 Optional<HookDescriptor> hook = scanForHook(nextMethod); 434 435 if (hook.isPresent()) { 436 int methodOrder = theTypeOrder; 437 int methodOrderAnnotation = hook.get().getOrder(); 438 if (methodOrderAnnotation != Interceptor.DEFAULT_ORDER) { 439 methodOrder = methodOrderAnnotation; 440 } 441 442 retVal.add(new HookInvoker(hook.get(), theInterceptor, nextMethod, methodOrder)); 443 } 444 } 445 446 return retVal; 447 } 448 449 protected abstract Optional<HookDescriptor> scanForHook(Method nextMethod); 450 451 private class HookInvoker extends BaseInvoker { 452 453 private final Method myMethod; 454 private final Class<?>[] myParameterTypes; 455 private final int[] myParameterIndexes; 456 private final POINTCUT myPointcut; 457 458 /** 459 * Constructor 460 */ 461 private HookInvoker(HookDescriptor theHook, @Nonnull Object theInterceptor, @Nonnull Method theHookMethod, int theOrder) { 462 super(theInterceptor, theOrder); 463 myPointcut = theHook.getPointcut(); 464 myParameterTypes = theHookMethod.getParameterTypes(); 465 myMethod = theHookMethod; 466 467 Class<?> returnType = theHookMethod.getReturnType(); 468 if (myPointcut.getReturnType().equals(boolean.class)) { 469 Validate.isTrue(boolean.class.equals(returnType) || void.class.equals(returnType), "Method does not return boolean or void: %s", theHookMethod); 470 } else if (myPointcut.getReturnType().equals(void.class)) { 471 Validate.isTrue(void.class.equals(returnType), "Method does not return void: %s", theHookMethod); 472 } else { 473 Validate.isTrue(myPointcut.getReturnType().isAssignableFrom(returnType) || void.class.equals(returnType), "Method does not return %s or void: %s", myPointcut.getReturnType(), theHookMethod); 474 } 475 476 myParameterIndexes = new int[myParameterTypes.length]; 477 Map<Class<?>, AtomicInteger> typeToCount = new HashMap<>(); 478 for (int i = 0; i < myParameterTypes.length; i++) { 479 AtomicInteger counter = typeToCount.computeIfAbsent(myParameterTypes[i], t -> new AtomicInteger(0)); 480 myParameterIndexes[i] = counter.getAndIncrement(); 481 } 482 483 myMethod.setAccessible(true); 484 } 485 486 @Override 487 public String toString() { 488 return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) 489 .append("method", myMethod) 490 .toString(); 491 } 492 493 public POINTCUT getPointcut() { 494 return myPointcut; 495 } 496 497 /** 498 * @return Returns true/false if the hook method returns a boolean, returns true otherwise 499 */ 500 @Override 501 Object invoke(HookParams theParams) { 502 503 Object[] args = new Object[myParameterTypes.length]; 504 for (int i = 0; i < myParameterTypes.length; i++) { 505 Class<?> nextParamType = myParameterTypes[i]; 506 if (nextParamType.equals(Pointcut.class)) { 507 args[i] = myPointcut; 508 } else { 509 int nextParamIndex = myParameterIndexes[i]; 510 Object nextParamValue = theParams.get(nextParamType, nextParamIndex); 511 args[i] = nextParamValue; 512 } 513 } 514 515 // Invoke the method 516 try { 517 return myMethod.invoke(getInterceptor(), args); 518 } catch (InvocationTargetException e) { 519 Throwable targetException = e.getTargetException(); 520 if (myPointcut.isShouldLogAndSwallowException(targetException)) { 521 ourLog.error("Exception thrown by interceptor: " + targetException.toString(), targetException); 522 return null; 523 } 524 525 if (targetException instanceof RuntimeException) { 526 throw ((RuntimeException) targetException); 527 } else { 528 throw new InternalErrorException(Msg.code(1910) + "Failure invoking interceptor for pointcut(s) " + getPointcut(), targetException); 529 } 530 } catch (Exception e) { 531 throw new InternalErrorException(Msg.code(1911) + e); 532 } 533 534 } 535 536 } 537 538 protected class HookDescriptor { 539 540 private final POINTCUT myPointcut; 541 private final int myOrder; 542 543 public HookDescriptor(POINTCUT thePointcut, int theOrder) { 544 myPointcut = thePointcut; 545 myOrder = theOrder; 546 } 547 548 POINTCUT getPointcut() { 549 return myPointcut; 550 } 551 552 int getOrder() { 553 return myOrder; 554 } 555 556 } 557 558 protected abstract static class BaseInvoker implements Comparable<BaseInvoker> { 559 560 private final int myOrder; 561 private final Object myInterceptor; 562 563 BaseInvoker(Object theInterceptor, int theOrder) { 564 myInterceptor = theInterceptor; 565 myOrder = theOrder; 566 } 567 568 public Object getInterceptor() { 569 return myInterceptor; 570 } 571 572 abstract Object invoke(HookParams theParams); 573 574 @Override 575 public int compareTo(BaseInvoker theInvoker) { 576 return myOrder - theInvoker.myOrder; 577 } 578 } 579 580 protected static <T extends Annotation> Optional<T> findAnnotation(AnnotatedElement theObject, Class<T> theHookClass) { 581 T annotation; 582 if (theObject instanceof Method) { 583 annotation = MethodUtils.getAnnotation((Method) theObject, theHookClass, true, true); 584 } else { 585 annotation = theObject.getAnnotation(theHookClass); 586 } 587 return Optional.ofNullable(annotation); 588 } 589 590 private static int determineOrder(Class<?> theInterceptorClass) { 591 return findAnnotation(theInterceptorClass, Interceptor.class) 592 .map(Interceptor::order) 593 .orElse(Interceptor.DEFAULT_ORDER); 594 } 595 596 private static String toErrorString(List<String> theParameterTypes) { 597 return theParameterTypes 598 .stream() 599 .sorted() 600 .collect(Collectors.joining(",")); 601 } 602 603}