Skip to content

Async inference model loader #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: inference-processor-simple
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.inference.AsyncModel;

import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.ml.inference.Model;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.function.BiConsumer;

/**
* A model (implements the inference function) that has its model state loaded
* from an index document via a {@link AsyncModelLoader}. When the AsyncModelLoader
* has fetched the document it will notify this class and subclasses then know how to
* construct the model.
*
* Any ingest documents arriving while waiting for the model state to load must be queued up.
*
* {@link #createModel(GetResponse)} should be implemented in subclasses to read
* the model state from the GetResponse supplied by the loader.
*
* {@link #inferPrivate(IngestDocument, BiConsumer)} does the actual inference.
*/
public abstract class AsyncModel implements Model {

    private final boolean ignoreMissing;

    // Written only while holding the instance monitor; volatile so infer()
    // can take an unsynchronized fast path once loading has completed.
    private volatile boolean isLoaded = false;
    private volatile Exception error;

    // Requests that arrived before the model state finished loading.
    private final Queue<Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>>> documentQueue;

    protected AsyncModel(boolean ignoreMissing) {
        this.ignoreMissing = ignoreMissing;
        documentQueue = new ConcurrentLinkedDeque<>();
    }

    /**
     * Infer on {@code document} if the model is ready, report any loading
     * error, otherwise queue the request until loading completes.
     *
     * @param document The ingest document
     * @param handler Called with the result or the loading error
     */
    @Override
    public void infer(IngestDocument document, BiConsumer<IngestDocument, Exception> handler) {
        // Fast path: loading finished successfully, no locking needed.
        if (isLoaded) {
            inferPrivate(document, handler);
            return;
        }

        if (error != null) {
            handler.accept(null, error);
            return;
        }

        // Still loading: queue the request, it is replayed by imLoaded()/setError()
        queueRequest(document, handler);
    }

    /**
     * Do the actual inference. Implementations should be threadsafe: this may
     * be called concurrently from infer() and from the drain of the queue.
     *
     * @param document The ingest document
     * @param handler Ingest handler
     */
    protected abstract void inferPrivate(IngestDocument document, BiConsumer<IngestDocument, Exception> handler);

    /**
     * Called by the loader once the model state document has been fetched.
     * Synchronized on the same monitor as queueRequest() so a request cannot
     * be queued after the drain and be stranded forever.
     */
    synchronized void imLoaded(GetResponse getResponse) {
        createModel(getResponse);
        // Mark loaded before draining so concurrent infer() calls take the fast path
        isLoaded = true;
        drainQueuedToInfer();
    }

    /**
     * Called by the loader if fetching the model state failed.
     */
    synchronized void setError(Exception exception) {
        // The error must be assigned before draining: the queued handlers read this.error
        this.error = exception;
        drainQueuedToError();
    }

    private synchronized void queueRequest(IngestDocument document, BiConsumer<IngestDocument, Exception> handler) {
        // Re-check under the lock: loading may have completed between the
        // unsynchronized checks in infer() and acquiring this monitor.
        if (isLoaded) {
            inferPrivate(document, handler);
        } else if (error != null) {
            handler.accept(null, error);
        } else {
            documentQueue.add(new Tuple<>(document, handler));
        }
    }

    private void drainQueuedToInfer() {
        // poll() rather than iterate so each request is removed and runs exactly once
        Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>> request;
        while ((request = documentQueue.poll()) != null) {
            inferPrivate(request.v1(), request.v2());
        }
    }

    private void drainQueuedToError() {
        Tuple<IngestDocument, BiConsumer<IngestDocument, Exception>> request;
        while ((request = documentQueue.poll()) != null) {
            request.v2().accept(null, error);
        }
    }

    public boolean isIgnoreMissing() {
        return ignoreMissing;
    }

    /**
     * Build the concrete model from the fetched state document.
     * Called while holding the instance monitor, before any queued
     * documents are replayed.
     */
    protected abstract void createModel(GetResponse getResponse);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.inference.AsyncModel;

import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.ingest.ConfigurationUtils;
import org.elasticsearch.xpack.ml.inference.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.ModelLoader;

import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

// This class is full of races.
//
// The general idea is that this class can be used to load any type of model where the
// model state has to be fetched from an index. TODO the class is poorly named

// The load() method handles fetching the index document and will return a subclass of AsyncModel (type T),
// registering that object as a listener to be notified once loading has finished
// or an error has occurred. The load() method can only be called once; if it is called
// more than once it should wait for the loading to finish and then notify all the listeners.

public abstract class AsyncModelLoader<T extends AsyncModel> implements ModelLoader {

    private static final Logger logger = LogManager.getLogger(AsyncModelLoader.class);

    public static final String INDEX = "index";

    private final Client client;
    private final Function<Boolean, T> modelSupplier;

    // Set to true once the async GET has completed (successfully or not).
    private final AtomicBoolean loadingFinished = new AtomicBoolean(false);
    private volatile GetResponse response;
    private volatile Exception loadingException;
    private volatile T loadedListener;

    protected AsyncModelLoader(Client client, Function<Boolean, T> modelSupplier) {
        this.client = client;
        this.modelSupplier = modelSupplier;
    }

    /**
     * Create the model and start the async fetch of its state document.
     * The returned model queues ingest documents until the fetch completes.
     *
     * @param modelId The model Id
     * @param processorTag Tag
     * @param ignoreMissing Passed through to the created model
     * @param config The processor config
     * @return The model, which may still be loading when returned
     */
    @Override
    public T load(String modelId, String processorTag, boolean ignoreMissing, Map<String, Object> config) {
        String index = readIndexName(processorTag, config);
        String documentId = documentId(modelId, config);

        // TODO if this method is called twice loadedListener will be overwritten.
        loadedListener = modelSupplier.apply(ignoreMissing);
        load(documentId, index);
        return loadedListener;
    }

    @Override
    public void consumeConfiguration(String processorTag, Map<String, Object> config) {
        readIndexName(processorTag, config);
    }

    /**
     * Read the name of the index to get the model state from.
     * The default is to read the string value of object {@value #INDEX}.
     *
     * @param processorTag Tag
     * @param config The processor config
     * @return The name of the index containing the model
     */
    protected String readIndexName(String processorTag, Map<String, Object> config) {
        return ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX);
    }

    /**
     * Construct the document Id used in the GET request.
     * This function is intended to be overridden, this implementation simply returns {@code modelId}
     *
     * @param modelId The model Id
     * @param config The processor config
     * @return The document Id
     */
    protected String documentId(String modelId, Map<String, Object> config) {
        return modelId;
    }

    private void load(String id, String index) {
        // loadingFinished is set by the listener callbacks, not here: the fetch
        // has not completed (previously the flag was flipped before execute()).
        ActionListener<GetResponse> listener = ActionListener.wrap(this::setResponse, this::setLoadingException);
        client.prepareGet(index, null, id).execute(listener);
    }

    private synchronized void setResponse(GetResponse response) {
        this.response = response;
        loadingFinished.set(true);
        if (loadedListener != null) {
            loadedListener.imLoaded(response);
        }
    }

    private synchronized void setLoadingException(Exception e) {
        this.loadingException = e;
        loadingFinished.set(true);
        // Notify the waiting model, otherwise its queued documents hang forever
        if (loadedListener != null) {
            loadedListener.setError(e);
        }
    }

    public GetResponse getGetResponse() {
        return response;
    }

    public Exception getLoadingException() {
        return loadingException;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,36 @@

package org.elasticsearch.xpack.ml.inference.sillymodel;

import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.ml.inference.Model;
import org.elasticsearch.xpack.ml.inference.AsyncModel.AsyncModel;

import java.util.Random;
import java.util.function.BiConsumer;

/**
* Trivial model whose only purpose is to aid code design
*/
public class SillyModel implements Model {
public class SillyModel extends AsyncModel {

private static final String TARGET_FIELD = "hotdog_or_not";

private final Random random;

public SillyModel() {
public SillyModel(boolean ignoreMissing) {
super(ignoreMissing);
random = Randomness.get();
}

public void infer(IngestDocument document, BiConsumer<IngestDocument, Exception> handler) {
@Override
public void inferPrivate(IngestDocument document, BiConsumer<IngestDocument, Exception> handler) {
document.setFieldValue(TARGET_FIELD, random.nextBoolean() ? "hotdog" : "not");
handler.accept(document, null);
}

@Override
protected void createModel(GetResponse getResponse) {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,68 +6,13 @@

package org.elasticsearch.xpack.ml.inference.sillymodel;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.ingest.ConfigurationUtils;
import org.elasticsearch.xpack.ml.inference.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.Model;
import org.elasticsearch.xpack.ml.inference.ModelLoader;
import org.elasticsearch.xpack.ml.inference.AsyncModel.AsyncModelLoader;

import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

public class SillyModelLoader implements ModelLoader {

public static final String MODEL_TYPE = "model_stored_in_index";

private static String INDEX = "index";

private final Client client;
public class SillyModelLoader extends AsyncModelLoader<SillyModel> {
public static final String MODEL_TYPE = "silly";

public SillyModelLoader(Client client) {
this.client = client;
}

@Override
public Model load(String modelId, String processorTag, boolean ignoreMissing, Map<String, Object> config) throws Exception {
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<Model> model = new AtomicReference<>();
AtomicReference<Exception> exception = new AtomicReference<>();

LatchedActionListener<Model> listener = new LatchedActionListener<>(
ActionListener.wrap(model::set, exception::set), latch
);

String index = ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX);

load(modelId, index, listener);
latch.await();
if (exception.get() != null) {
throw exception.get();
}

return model.get();
}

@Override
public void consumeConfiguration(String processorTag, Map<String, Object> config) {
ConfigurationUtils.readStringProperty(InferenceProcessor.TYPE, processorTag, config, INDEX);
}


private void load(String id, String index, ActionListener<Model> listener) {
client.prepareGet(index, null, id).execute(ActionListener.wrap(
response -> {
if (response.isExists()) {
listener.onResponse(new SillyModel());
} else {
listener.onFailure(new ResourceNotFoundException("missing model [{}], [{}]", id, index));
}
},
listener::onFailure
));
super(client, SillyModel::new);
}
}
Loading