/*
 * Decompiled with CFR 0.152.
 */
package software.amazon.smithy.model.transform.plugins;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.AuthDefinitionTrait;
import software.amazon.smithy.model.traits.ProtocolDefinitionTrait;
import software.amazon.smithy.model.transform.ModelTransformer;
import software.amazon.smithy.model.transform.ModelTransformerPlugin;

/**
 * Model transformer plugin that, after shapes are removed, strips references to those
 * removed shapes from the {@code traits} lists of structures marked with
 * {@link AuthDefinitionTrait} or {@link ProtocolDefinitionTrait}.
 */
public final class CleanTraitDefinitions implements ModelTransformerPlugin {

    /**
     * Removes stale trait references left behind by a shape removal.
     *
     * <p>Auth definitions are processed first. The protocol-definition pass then reads from
     * the model returned by that first {@code replaceShapes} call rather than from the
     * incoming model.
     *
     * @param transformer transformer used to swap rewritten structures into the model
     * @param removed shapes that were removed
     * @param model model to clean
     * @return the model after both replacement passes
     */
    @Override
    public Model onRemove(ModelTransformer transformer, Collection<Shape> removed, Model model) {
        Set<ShapeId> removedIds = new HashSet<>();
        for (Shape gone : removed) {
            removedIds.add(gone.getId());
        }

        Set<Shape> authUpdates = getAuthDefShapesToReplace(model, removedIds);
        Model authCleaned = transformer.replaceShapes(model, authUpdates);

        Set<Shape> protocolUpdates = getProtocolDefShapesToReplace(authCleaned, removedIds);
        return transformer.replaceShapes(authCleaned, protocolUpdates);
    }

    /**
     * Builds rewritten copies of every {@link AuthDefinitionTrait} structure whose
     * {@code traits} list mentions at least one removed shape ID.
     *
     * @param model model to scan
     * @param removedShapeIds IDs of shapes that were removed
     * @return rewritten structures; structures needing no change are not included
     */
    private Set<Shape> getAuthDefShapesToReplace(Model model, Set<ShapeId> removedShapeIds) {
        Set<Shape> replacements = new HashSet<>();
        for (StructureShape definition : model.getStructureShapesWithTrait(AuthDefinitionTrait.class)) {
            AuthDefinitionTrait authTrait = definition.expectTrait(AuthDefinitionTrait.class);
            List<ShapeId> referenced = authTrait.getTraits();

            // Only rebuild the structure when filtering would actually drop something.
            if (referenced.stream().anyMatch(removedShapeIds::contains)) {
                List<ShapeId> kept = excludeTraitsInSet(referenced, removedShapeIds);
                AuthDefinitionTrait updatedTrait = authTrait.toBuilder().traits(kept).build();
                StructureShape.Builder rebuilt =
                        (StructureShape.Builder) definition.toBuilder().addTrait(updatedTrait);
                replacements.add(rebuilt.build());
            }
        }
        return replacements;
    }

    /**
     * Builds rewritten copies of every {@link ProtocolDefinitionTrait} structure whose
     * {@code traits} list mentions at least one removed shape ID.
     *
     * @param model model to scan
     * @param removedShapeIds IDs of shapes that were removed
     * @return rewritten structures; structures needing no change are not included
     */
    private Set<Shape> getProtocolDefShapesToReplace(Model model, Set<ShapeId> removedShapeIds) {
        Set<Shape> replacements = new HashSet<>();
        for (StructureShape definition : model.getStructureShapesWithTrait(ProtocolDefinitionTrait.class)) {
            ProtocolDefinitionTrait protocolTrait = definition.expectTrait(ProtocolDefinitionTrait.class);
            List<ShapeId> referenced = protocolTrait.getTraits();

            // Only rebuild the structure when filtering would actually drop something.
            if (referenced.stream().anyMatch(removedShapeIds::contains)) {
                List<ShapeId> kept = excludeTraitsInSet(referenced, removedShapeIds);
                ProtocolDefinitionTrait updatedTrait = protocolTrait.toBuilder().traits(kept).build();
                StructureShape.Builder rebuilt =
                        (StructureShape.Builder) definition.toBuilder().addTrait(updatedTrait);
                replacements.add(rebuilt.build());
            }
        }
        return replacements;
    }

    /**
     * Returns a new list holding the entries of {@code traits}, in their original order,
     * that are not contained in {@code shapeIds}.
     */
    private List<ShapeId> excludeTraitsInSet(List<ShapeId> traits, Set<ShapeId> shapeIds) {
        return traits.stream()
                .filter(candidate -> !shapeIds.contains(candidate))
                .collect(Collectors.toList());
    }
}

