001/*-
002 * #%L
003 * HAPI FHIR - Core Library
004 * %%
005 * Copyright (C) 2014 - 2023 Smile CDR, Inc.
006 * %%
007 * Licensed under the Apache License, Version 2.0 (the "License");
008 * you may not use this file except in compliance with the License.
009 * You may obtain a copy of the License at
010 *
011 *      http://www.apache.org/licenses/LICENSE-2.0
012 *
013 * Unless required by applicable law or agreed to in writing, software
014 * distributed under the License is distributed on an "AS IS" BASIS,
015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
016 * See the License for the specific language governing permissions and
017 * limitations under the License.
018 * #L%
019 */
020package ca.uhn.fhir.interceptor.executor;
021
022import ca.uhn.fhir.i18n.Msg;
023import ca.uhn.fhir.interceptor.api.HookParams;
024import ca.uhn.fhir.interceptor.api.IBaseInterceptorBroadcaster;
025import ca.uhn.fhir.interceptor.api.IBaseInterceptorService;
026import ca.uhn.fhir.interceptor.api.IPointcut;
027import ca.uhn.fhir.interceptor.api.Interceptor;
028import ca.uhn.fhir.interceptor.api.Pointcut;
029import ca.uhn.fhir.rest.server.exceptions.InternalErrorException;
030import ca.uhn.fhir.util.ReflectionUtil;
031import com.google.common.annotations.VisibleForTesting;
032import com.google.common.collect.ArrayListMultimap;
033import com.google.common.collect.ListMultimap;
034import org.apache.commons.lang3.Validate;
035import org.apache.commons.lang3.builder.ToStringBuilder;
036import org.apache.commons.lang3.builder.ToStringStyle;
037import org.apache.commons.lang3.reflect.MethodUtils;
038import org.slf4j.Logger;
039import org.slf4j.LoggerFactory;
040
041import javax.annotation.Nonnull;
042import javax.annotation.Nullable;
043import java.lang.annotation.Annotation;
044import java.lang.reflect.AnnotatedElement;
045import java.lang.reflect.InvocationTargetException;
046import java.lang.reflect.Method;
047import java.util.ArrayList;
048import java.util.Arrays;
049import java.util.Collection;
050import java.util.Collections;
051import java.util.Comparator;
052import java.util.EnumSet;
053import java.util.HashMap;
054import java.util.IdentityHashMap;
055import java.util.List;
056import java.util.Map;
057import java.util.Objects;
058import java.util.Optional;
059import java.util.concurrent.atomic.AtomicInteger;
060import java.util.function.Predicate;
061import java.util.stream.Collectors;
062
063public abstract class BaseInterceptorService<POINTCUT extends Enum<POINTCUT> & IPointcut> implements IBaseInterceptorService<POINTCUT>, IBaseInterceptorBroadcaster<POINTCUT> {
064        private static final Logger ourLog = LoggerFactory.getLogger(BaseInterceptorService.class);
065        private final List<Object> myInterceptors = new ArrayList<>();
066        private final ListMultimap<POINTCUT, BaseInvoker> myGlobalInvokers = ArrayListMultimap.create();
067        private final ListMultimap<POINTCUT, BaseInvoker> myAnonymousInvokers = ArrayListMultimap.create();
068        private final Object myRegistryMutex = new Object();
069        private final Class<POINTCUT> myPointcutType;
070        private volatile EnumSet<POINTCUT> myRegisteredPointcuts;
071        private String myName;
072        private boolean myWarnOnInterceptorWithNoHooks = true;
073
074        /**
075         * Constructor which uses a default name of "default"
076         */
077        public BaseInterceptorService(Class<POINTCUT> thePointcutType) {
078                this(thePointcutType, "default");
079        }
080
081        /**
082         * Constructor
083         *
084         * @param theName The name for this registry (useful for troubleshooting)
085         */
086        public BaseInterceptorService(Class<POINTCUT> thePointcutType, String theName) {
087                super();
088                myName = theName;
089                myPointcutType = thePointcutType;
090                rebuildRegisteredPointcutSet();
091        }
092
093        /**
094         * Should a warning be issued if an interceptor is registered and it has no hooks
095         */
096        public void setWarnOnInterceptorWithNoHooks(boolean theWarnOnInterceptorWithNoHooks) {
097                myWarnOnInterceptorWithNoHooks = theWarnOnInterceptorWithNoHooks;
098        }
099
100        @VisibleForTesting
101        List<Object> getGlobalInterceptorsForUnitTest() {
102                return myInterceptors;
103        }
104
105        public void setName(String theName) {
106                myName = theName;
107        }
108
109        protected void registerAnonymousInterceptor(POINTCUT thePointcut, Object theInterceptor, BaseInvoker theInvoker) {
110                Validate.notNull(thePointcut);
111                Validate.notNull(theInterceptor);
112                synchronized (myRegistryMutex) {
113
114                        myAnonymousInvokers.put(thePointcut, theInvoker);
115                        if (!isInterceptorAlreadyRegistered(theInterceptor)) {
116                                myInterceptors.add(theInterceptor);
117                        }
118
119                        rebuildRegisteredPointcutSet();
120                }
121        }
122
123        @Override
124        public List<Object> getAllRegisteredInterceptors() {
125                synchronized (myRegistryMutex) {
126                        List<Object> retVal = new ArrayList<>(myInterceptors);
127                        return Collections.unmodifiableList(retVal);
128                }
129        }
130
131        @Override
132        @VisibleForTesting
133        public void unregisterAllInterceptors() {
134                synchronized (myRegistryMutex) {
135                        unregisterInterceptors(myAnonymousInvokers.values());
136                        unregisterInterceptors(myGlobalInvokers.values());
137                        unregisterInterceptors(myInterceptors);
138                }
139        }
140
141        @Override
142        public void unregisterInterceptors(@Nullable Collection<?> theInterceptors) {
143                if (theInterceptors != null) {
144                        // We construct a new list before iterating because the service's internal
145                        // interceptor lists get passed into this method, and we get concurrent
146                        // modification errors if we modify them at the same time as we iterate them
147                        new ArrayList<>(theInterceptors).forEach(this::unregisterInterceptor);
148                }
149        }
150
151        @Override
152        public void registerInterceptors(@Nullable Collection<?> theInterceptors) {
153                if (theInterceptors != null) {
154                        theInterceptors.forEach(this::registerInterceptor);
155                }
156        }
157
158        @Override
159        public void unregisterAllAnonymousInterceptors() {
160                synchronized (myRegistryMutex) {
161                        unregisterInterceptorsIf(t -> true, myAnonymousInvokers);
162                }
163        }
164
165        @Override
166        public void unregisterInterceptorsIf(Predicate<Object> theShouldUnregisterFunction) {
167                unregisterInterceptorsIf(theShouldUnregisterFunction, myGlobalInvokers);
168                unregisterInterceptorsIf(theShouldUnregisterFunction, myAnonymousInvokers);
169        }
170
171        private void unregisterInterceptorsIf(Predicate<Object> theShouldUnregisterFunction, ListMultimap<POINTCUT, BaseInvoker> theGlobalInvokers) {
172                synchronized (myRegistryMutex) {
173                        for (Map.Entry<POINTCUT, BaseInvoker> nextInvoker : new ArrayList<>(theGlobalInvokers.entries())) {
174                                if (theShouldUnregisterFunction.test(nextInvoker.getValue().getInterceptor())) {
175                                        unregisterInterceptor(nextInvoker.getValue().getInterceptor());
176                                }
177                        }
178
179                        rebuildRegisteredPointcutSet();
180                }
181        }
182
183        @Override
184        public boolean registerInterceptor(Object theInterceptor) {
185                synchronized (myRegistryMutex) {
186
187                        if (isInterceptorAlreadyRegistered(theInterceptor)) {
188                                return false;
189                        }
190
191                        List<HookInvoker> addedInvokers = scanInterceptorAndAddToInvokerMultimap(theInterceptor, myGlobalInvokers);
192                        if (addedInvokers.isEmpty()) {
193                                if (myWarnOnInterceptorWithNoHooks) {
194                                        ourLog.warn("Interceptor registered with no valid hooks - Type was: {}", theInterceptor.getClass().getName());
195                                }
196                                return false;
197                        }
198
199                        // Add to the global list
200                        myInterceptors.add(theInterceptor);
201                        sortByOrderAnnotation(myInterceptors);
202
203                        rebuildRegisteredPointcutSet();
204
205                        return true;
206                }
207        }
208
209        private void rebuildRegisteredPointcutSet() {
210                EnumSet<POINTCUT> registeredPointcuts = EnumSet.noneOf(myPointcutType);
211                registeredPointcuts.addAll(myAnonymousInvokers.keySet());
212                registeredPointcuts.addAll(myGlobalInvokers.keySet());
213                myRegisteredPointcuts = registeredPointcuts;
214        }
215
216        private boolean isInterceptorAlreadyRegistered(Object theInterceptor) {
217                for (Object next : myInterceptors) {
218                        if (next == theInterceptor) {
219                                return true;
220                        }
221                }
222                return false;
223        }
224
225        @Override
226        public boolean unregisterInterceptor(Object theInterceptor) {
227                synchronized (myRegistryMutex) {
228                        boolean removed = myInterceptors.removeIf(t -> t == theInterceptor);
229                        removed |= myGlobalInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
230                        removed |= myAnonymousInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
231                        rebuildRegisteredPointcutSet();
232                        return removed;
233                }
234        }
235
236        private void sortByOrderAnnotation(List<Object> theObjects) {
237                IdentityHashMap<Object, Integer> interceptorToOrder = new IdentityHashMap<>();
238                for (Object next : theObjects) {
239                        Interceptor orderAnnotation = next.getClass().getAnnotation(Interceptor.class);
240                        int order = orderAnnotation != null ? orderAnnotation.order() : 0;
241                        interceptorToOrder.put(next, order);
242                }
243
244                theObjects.sort((a, b) -> {
245                        Integer orderA = interceptorToOrder.get(a);
246                        Integer orderB = interceptorToOrder.get(b);
247                        return orderA - orderB;
248                });
249        }
250
251        @Override
252        public Object callHooksAndReturnObject(POINTCUT thePointcut, HookParams theParams) {
253                assert haveAppropriateParams(thePointcut, theParams);
254                assert thePointcut.getReturnType() != void.class;
255
256                return doCallHooks(thePointcut, theParams, null);
257        }
258
259        @Override
260        public boolean hasHooks(POINTCUT thePointcut) {
261                return myRegisteredPointcuts.contains(thePointcut);
262        }
263
264        @Override
265        public boolean callHooks(POINTCUT thePointcut, HookParams theParams) {
266                assert haveAppropriateParams(thePointcut, theParams);
267                assert thePointcut.getReturnType() == void.class || thePointcut.getReturnType() == boolean.class;
268
269                Object retValObj = doCallHooks(thePointcut, theParams, true);
270                return (Boolean) retValObj;
271        }
272
273        private Object doCallHooks(POINTCUT thePointcut, HookParams theParams, Object theRetVal) {
274                // use new list for loop to avoid ConcurrentModificationException in case invoker gets added while looping
275                List<BaseInvoker> invokers = new ArrayList<>(getInvokersForPointcut(thePointcut));
276
277                /*
278                 * Call each hook in order
279                 */
280                for (BaseInvoker nextInvoker : invokers) {
281                        Object nextOutcome = nextInvoker.invoke(theParams);
282                        Class<?> pointcutReturnType = thePointcut.getReturnType();
283                        if (pointcutReturnType.equals(boolean.class)) {
284                                Boolean nextOutcomeAsBoolean = (Boolean) nextOutcome;
285                                if (Boolean.FALSE.equals(nextOutcomeAsBoolean)) {
286                                        ourLog.trace("callHooks({}) for invoker({}) returned false", thePointcut, nextInvoker);
287                                        theRetVal = false;
288                                        break;
289                                }
290                        } else if (pointcutReturnType.equals(void.class) == false) {
291                                if (nextOutcome != null) {
292                                        theRetVal = nextOutcome;
293                                        break;
294                                }
295                        }
296                }
297
298                return theRetVal;
299        }
300
301        @VisibleForTesting
302        List<Object> getInterceptorsWithInvokersForPointcut(POINTCUT thePointcut) {
303                return getInvokersForPointcut(thePointcut)
304                        .stream()
305                        .map(BaseInvoker::getInterceptor)
306                        .collect(Collectors.toList());
307        }
308
309        /**
310         * Returns an ordered list of invokers for the given pointcut. Note that
311         * a new and stable list is returned to.. do whatever you want with it.
312         */
313        private List<BaseInvoker> getInvokersForPointcut(POINTCUT thePointcut) {
314                List<BaseInvoker> invokers;
315
316                synchronized (myRegistryMutex) {
317                        List<BaseInvoker> globalInvokers = myGlobalInvokers.get(thePointcut);
318                        List<BaseInvoker> anonymousInvokers = myAnonymousInvokers.get(thePointcut);
319                        List<BaseInvoker> threadLocalInvokers = null;
320                        invokers = union(globalInvokers, anonymousInvokers, threadLocalInvokers);
321                }
322
323                return invokers;
324        }
325
326        /**
327         * First argument must be the global invoker list!!
328         */
329        @SafeVarargs
330        private List<BaseInvoker> union(List<BaseInvoker>... theInvokersLists) {
331                List<BaseInvoker> haveOne = null;
332                boolean haveMultiple = false;
333                for (List<BaseInvoker> nextInvokerList : theInvokersLists) {
334                        if (nextInvokerList == null || nextInvokerList.isEmpty()) {
335                                continue;
336                        }
337
338                        if (haveOne == null) {
339                                haveOne = nextInvokerList;
340                        } else {
341                                haveMultiple = true;
342                        }
343                }
344
345                if (haveOne == null) {
346                        return Collections.emptyList();
347                }
348
349                List<BaseInvoker> retVal;
350
351                if (haveMultiple == false) {
352
353                        // The global list doesn't need to be sorted every time since it's sorted on
354                        // insertion each time. Doing so is a waste of cycles..
355                        if (haveOne == theInvokersLists[0]) {
356                                retVal = haveOne;
357                        } else {
358                                retVal = new ArrayList<>(haveOne);
359                                retVal.sort(Comparator.naturalOrder());
360                        }
361
362                } else {
363
364                        retVal = Arrays
365                                .stream(theInvokersLists)
366                                .filter(Objects::nonNull)
367                                .flatMap(Collection::stream)
368                                .sorted()
369                                .collect(Collectors.toList());
370
371                }
372
373                return retVal;
374        }
375
376        /**
377         * Only call this when assertions are enabled, it's expensive
378         */
379        final boolean haveAppropriateParams(POINTCUT thePointcut, HookParams theParams) {
380                if (theParams.getParamsForType().values().size() != thePointcut.getParameterTypes().size()) {
381                        throw new IllegalArgumentException(Msg.code(1909) + String.format("Wrong number of params for pointcut %s - Wanted %s but found %s", thePointcut.name(), toErrorString(thePointcut.getParameterTypes()), theParams.getParamsForType().values().stream().map(t -> t != null ? t.getClass().getSimpleName() : "null").sorted().collect(Collectors.toList())));
382                }
383
384                List<String> wantedTypes = new ArrayList<>(thePointcut.getParameterTypes());
385
386                ListMultimap<Class<?>, Object> givenTypes = theParams.getParamsForType();
387                for (Class<?> nextTypeClass : givenTypes.keySet()) {
388                        String nextTypeName = nextTypeClass.getName();
389                        for (Object nextParamValue : givenTypes.get(nextTypeClass)) {
390                                Validate.isTrue(nextParamValue == null || nextTypeClass.isAssignableFrom(nextParamValue.getClass()), "Invalid params for pointcut %s - %s is not of type %s", thePointcut.name(), nextParamValue != null ? nextParamValue.getClass() : "null", nextTypeClass);
391                                Validate.isTrue(wantedTypes.remove(nextTypeName), "Invalid params for pointcut %s - Wanted %s but found %s", thePointcut.name(), toErrorString(thePointcut.getParameterTypes()), nextTypeName);
392                        }
393                }
394
395                return true;
396        }
397
398        private List<HookInvoker> scanInterceptorAndAddToInvokerMultimap(Object theInterceptor, ListMultimap<POINTCUT, BaseInvoker> theInvokers) {
399                Class<?> interceptorClass = theInterceptor.getClass();
400                int typeOrder = determineOrder(interceptorClass);
401
402                List<HookInvoker> addedInvokers = scanInterceptorForHookMethods(theInterceptor, typeOrder);
403
404                // Invoke the REGISTERED pointcut for any added hooks
405                addedInvokers.stream()
406                        .filter(t -> Pointcut.INTERCEPTOR_REGISTERED.equals(t.getPointcut()))
407                        .forEach(t -> t.invoke(new HookParams()));
408
409                // Register the interceptor and its various hooks
410                for (HookInvoker nextAddedHook : addedInvokers) {
411                        POINTCUT nextPointcut = nextAddedHook.getPointcut();
412                        if (nextPointcut.equals(Pointcut.INTERCEPTOR_REGISTERED)) {
413                                continue;
414                        }
415                        theInvokers.put(nextPointcut, nextAddedHook);
416                }
417
418                // Make sure we're always sorted according to the order declared in @Order
419                for (POINTCUT nextPointcut : theInvokers.keys()) {
420                        List<BaseInvoker> nextInvokerList = theInvokers.get(nextPointcut);
421                        nextInvokerList.sort(Comparator.naturalOrder());
422                }
423
424                return addedInvokers;
425        }
426
427        /**
428         * @return Returns a list of any added invokers
429         */
430        private List<HookInvoker> scanInterceptorForHookMethods(Object theInterceptor, int theTypeOrder) {
431                ArrayList<HookInvoker> retVal = new ArrayList<>();
432                for (Method nextMethod : ReflectionUtil.getDeclaredMethods(theInterceptor.getClass(), true)) {
433                        Optional<HookDescriptor> hook = scanForHook(nextMethod);
434
435                        if (hook.isPresent()) {
436                                int methodOrder = theTypeOrder;
437                                int methodOrderAnnotation = hook.get().getOrder();
438                                if (methodOrderAnnotation != Interceptor.DEFAULT_ORDER) {
439                                        methodOrder = methodOrderAnnotation;
440                                }
441
442                                retVal.add(new HookInvoker(hook.get(), theInterceptor, nextMethod, methodOrder));
443                        }
444                }
445
446                return retVal;
447        }
448
449        protected abstract Optional<HookDescriptor> scanForHook(Method nextMethod);
450
451        private class HookInvoker extends BaseInvoker {
452
453                private final Method myMethod;
454                private final Class<?>[] myParameterTypes;
455                private final int[] myParameterIndexes;
456                private final POINTCUT myPointcut;
457
458                /**
459                 * Constructor
460                 */
461                private HookInvoker(HookDescriptor theHook, @Nonnull Object theInterceptor, @Nonnull Method theHookMethod, int theOrder) {
462                        super(theInterceptor, theOrder);
463                        myPointcut = theHook.getPointcut();
464                        myParameterTypes = theHookMethod.getParameterTypes();
465                        myMethod = theHookMethod;
466
467                        Class<?> returnType = theHookMethod.getReturnType();
468                        if (myPointcut.getReturnType().equals(boolean.class)) {
469                                Validate.isTrue(boolean.class.equals(returnType) || void.class.equals(returnType), "Method does not return boolean or void: %s", theHookMethod);
470                        } else if (myPointcut.getReturnType().equals(void.class)) {
471                                Validate.isTrue(void.class.equals(returnType), "Method does not return void: %s", theHookMethod);
472                        } else {
473                                Validate.isTrue(myPointcut.getReturnType().isAssignableFrom(returnType) || void.class.equals(returnType), "Method does not return %s or void: %s", myPointcut.getReturnType(), theHookMethod);
474                        }
475
476                        myParameterIndexes = new int[myParameterTypes.length];
477                        Map<Class<?>, AtomicInteger> typeToCount = new HashMap<>();
478                        for (int i = 0; i < myParameterTypes.length; i++) {
479                                AtomicInteger counter = typeToCount.computeIfAbsent(myParameterTypes[i], t -> new AtomicInteger(0));
480                                myParameterIndexes[i] = counter.getAndIncrement();
481                        }
482
483                        myMethod.setAccessible(true);
484                }
485
486                @Override
487                public String toString() {
488                        return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
489                                .append("method", myMethod)
490                                .toString();
491                }
492
493                public POINTCUT getPointcut() {
494                        return myPointcut;
495                }
496
497                /**
498                 * @return Returns true/false if the hook method returns a boolean, returns true otherwise
499                 */
500                @Override
501                Object invoke(HookParams theParams) {
502
503                        Object[] args = new Object[myParameterTypes.length];
504                        for (int i = 0; i < myParameterTypes.length; i++) {
505                                Class<?> nextParamType = myParameterTypes[i];
506                                if (nextParamType.equals(Pointcut.class)) {
507                                        args[i] = myPointcut;
508                                } else {
509                                        int nextParamIndex = myParameterIndexes[i];
510                                        Object nextParamValue = theParams.get(nextParamType, nextParamIndex);
511                                        args[i] = nextParamValue;
512                                }
513                        }
514
515                        // Invoke the method
516                        try {
517                                return myMethod.invoke(getInterceptor(), args);
518                        } catch (InvocationTargetException e) {
519                                Throwable targetException = e.getTargetException();
520                                if (myPointcut.isShouldLogAndSwallowException(targetException)) {
521                                        ourLog.error("Exception thrown by interceptor: " + targetException.toString(), targetException);
522                                        return null;
523                                }
524
525                                if (targetException instanceof RuntimeException) {
526                                        throw ((RuntimeException) targetException);
527                                } else {
528                                        throw new InternalErrorException(Msg.code(1910) + "Failure invoking interceptor for pointcut(s) " + getPointcut(), targetException);
529                                }
530                        } catch (Exception e) {
531                                throw new InternalErrorException(Msg.code(1911) + e);
532                        }
533
534                }
535
536        }
537
538        protected class HookDescriptor {
539
540                private final POINTCUT myPointcut;
541                private final int myOrder;
542
543                public HookDescriptor(POINTCUT thePointcut, int theOrder) {
544                        myPointcut = thePointcut;
545                        myOrder = theOrder;
546                }
547
548                POINTCUT getPointcut() {
549                        return myPointcut;
550                }
551
552                int getOrder() {
553                        return myOrder;
554                }
555
556        }
557
558        protected abstract static class BaseInvoker implements Comparable<BaseInvoker> {
559
560                private final int myOrder;
561                private final Object myInterceptor;
562
563                BaseInvoker(Object theInterceptor, int theOrder) {
564                        myInterceptor = theInterceptor;
565                        myOrder = theOrder;
566                }
567
568                public Object getInterceptor() {
569                        return myInterceptor;
570                }
571
572                abstract Object invoke(HookParams theParams);
573
574                @Override
575                public int compareTo(BaseInvoker theInvoker) {
576                        return myOrder - theInvoker.myOrder;
577                }
578        }
579
580        protected static <T extends Annotation> Optional<T> findAnnotation(AnnotatedElement theObject, Class<T> theHookClass) {
581                T annotation;
582                if (theObject instanceof Method) {
583                        annotation = MethodUtils.getAnnotation((Method) theObject, theHookClass, true, true);
584                } else {
585                        annotation = theObject.getAnnotation(theHookClass);
586                }
587                return Optional.ofNullable(annotation);
588        }
589
590        private static int determineOrder(Class<?> theInterceptorClass) {
591                return findAnnotation(theInterceptorClass, Interceptor.class)
592                        .map(Interceptor::order)
593                        .orElse(Interceptor.DEFAULT_ORDER);
594        }
595
596        private static String toErrorString(List<String> theParameterTypes) {
597                return theParameterTypes
598                        .stream()
599                        .sorted()
600                        .collect(Collectors.joining(","));
601        }
602
603}