Get TabNet online inferences
This page shows you how to get online (real-time) inferences and explanations
from your tabular classification or regression models using the Google Cloud console
or the Vertex AI API.
An online inference is a synchronous request, as opposed to a
batch inference, which is an asynchronous request. Use online inferences when you are making
requests in response to application input or in other situations where you
require timely inference.
You must deploy a model to an endpoint before that model can be used to serve
online inferences. Deploying a model associates physical resources with the
model so it can serve online inferences with low latency.
Before you can get online inferences, you must first train a model.
Deploy a model to an endpoint
You can deploy more than one model to an endpoint, and you can deploy a model to
more than one endpoint. For more information about options and use cases for
deploying models, see About deploying models.
Use one of the following methods to deploy a model:
Google Cloud console
In the Google Cloud console, in the Vertex AI section, go to
the Models page.
Click the name of the model you want to deploy to open its details page.
Select the Deploy & Test tab.
If your model is already deployed to any endpoints, they are listed in the
Deploy your model section.
Click Deploy to endpoint.
In the Define your endpoint page, configure as follows:
You can choose to deploy your model to a new endpoint or an existing endpoint.
To deploy your model to a new endpoint, select
Create new endpoint
and provide a name for the new endpoint.
To deploy your model to an existing endpoint, select
Add to existing endpoint
and select the endpoint from the drop-down list.
You can add more than one model to an endpoint, and you can add a model
to more than one endpoint.
Click Continue.
In the Model settings page, configure as follows:
If you're deploying your model to a new endpoint, accept 100 for the
Traffic split. If you're deploying your model to an existing endpoint that has one or
more models deployed to it, you must update the Traffic split
percentage for the model you are deploying and the already deployed models
so that all of the percentages add up to 100%.
Enter the Minimum number of compute nodes you want to provide for
your model.
This is the number of nodes available to this model at all times.
You are charged for every node in use, whether it is handling inference load
or standing by as a minimum node, even when there is no inference traffic.
See the pricing page.
Select your Machine type.
Larger machine types improve inference performance
but also increase costs.
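When you deploy through the API instead of the console, you express the traffic split as a map from deployed model IDs to percentages. The DeployModel request documents the key "0" as referring to the model being deployed. The percentages must still add up to 100.

The following is a minimal sketch, not part of any Google library. It assumes you want the already deployed models to keep their relative shares. `rebalance_traffic_split` is a hypothetical helper that scales them down with largest-remainder rounding, so the result always totals exactly 100.

```python
from typing import Dict


def rebalance_traffic_split(
    existing: Dict[str, int], new_model_key: str, new_percent: int
) -> Dict[str, int]:
    """Give new_model_key new_percent of traffic and scale the other models to fit.

    existing maps already deployed model IDs to their current percentages.
    The returned split always adds up to exactly 100.
    """
    if not existing:
        # New endpoint: the only model receives 100 percent of traffic.
        return {new_model_key: 100}
    if sum(existing.values()) != 100:
        raise ValueError("existing traffic split must add up to 100")
    if not 0 <= new_percent <= 100:
        raise ValueError("new_percent must be between 0 and 100")

    remaining = 100 - new_percent
    split: Dict[str, int] = {}
    remainders: Dict[str, int] = {}
    for model_id, percent in existing.items():
        # Integer arithmetic avoids floating-point rounding surprises.
        split[model_id], remainders[model_id] = divmod(percent * remaining, 100)

    # Largest-remainder rounding: give leftover points to the biggest remainders.
    leftover = remaining - sum(split.values())
    for model_id in sorted(remainders, key=remainders.get, reverse=True)[:leftover]:
        split[model_id] += 1

    split[new_model_key] = new_percent
    return split
```

For example, `rebalance_traffic_split({"a": 50, "b": 50}, "0", 25)` returns `{"a": 38, "b": 37, "0": 25}`. Ties are broken in insertion order.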
import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.aiplatform.v1.CreateEndpointOperationMetadata;
import com.google.cloud.aiplatform.v1.Endpoint;
import com.google.cloud.aiplatform.v1.EndpointServiceClient;
import com.google.cloud.aiplatform.v1.EndpointServiceSettings;
import com.google.cloud.aiplatform.v1.LocationName;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class CreateEndpointSample {

  public static void main(String[] args)
      throws IOException, InterruptedException, ExecutionException, TimeoutException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String endpointDisplayName = "YOUR_ENDPOINT_DISPLAY_NAME";
    createEndpointSample(project, endpointDisplayName);
  }

  static void createEndpointSample(String project, String endpointDisplayName)
      throws IOException, InterruptedException, ExecutionException, TimeoutException {
    EndpointServiceSettings endpointServiceSettings =
        EndpointServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (EndpointServiceClient endpointServiceClient =
        EndpointServiceClient.create(endpointServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      Endpoint endpoint = Endpoint.newBuilder().setDisplayName(endpointDisplayName).build();

      // Creating an endpoint is a long-running operation.
      OperationFuture<Endpoint, CreateEndpointOperationMetadata> endpointFuture =
          endpointServiceClient.createEndpointAsync(locationName, endpoint);
      System.out.format("Operation name: %s\n", endpointFuture.getInitialFuture().get().getName());
      System.out.println("Waiting for operation to finish...");
      Endpoint endpointResponse = endpointFuture.get(300, TimeUnit.SECONDS);

      System.out.println("Create Endpoint Response");
      System.out.format("Name: %s\n", endpointResponse.getName());
      System.out.format("Display Name: %s\n", endpointResponse.getDisplayName());
      System.out.format("Description: %s\n", endpointResponse.getDescription());
      System.out.format("Labels: %s\n", endpointResponse.getLabelsMap());
      System.out.format("Create Time: %s\n", endpointResponse.getCreateTime());
      System.out.format("Update Time: %s\n", endpointResponse.getUpdateTime());
    }
  }
}
/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
// const endpointDisplayName = 'YOUR_ENDPOINT_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';

// Imports the Google Cloud Endpoint Service Client library
const {EndpointServiceClient} = require('@google-cloud/aiplatform');

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const endpointServiceClient = new EndpointServiceClient(clientOptions);

async function createEndpoint() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;
  const endpoint = {
    displayName: endpointDisplayName,
  };
  const request = {
    parent,
    endpoint,
  };

  // Start the long-running operation that creates the endpoint
  const [response] = await endpointServiceClient.createEndpoint(request);
  console.log(`Long running operation : ${response.name}`);

  // Wait for operation to complete
  await response.promise();
  const result = response.result;

  console.log('Create endpoint response');
  console.log(`\tName : ${result.name}`);
  console.log(`\tDisplay name : ${result.displayName}`);
  console.log(`\tDescription : ${result.description}`);
  console.log(`\tLabels : ${JSON.stringify(result.labels)}`);
  console.log(`\tCreate time : ${JSON.stringify(result.createTime)}`);
  console.log(`\tUpdate time : ${JSON.stringify(result.updateTime)}`);
}
createEndpoint();
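Both samples print the new endpoint's full resource name. The prediction requests later on this page need only the ENDPOINT_ID part of it. The following minimal sketch splits a name of the form `projects/PROJECT/locations/LOCATION/endpoints/ENDPOINT_ID` into its parts. It assumes that format, which matches the example in the Python sample's docstring. `parse_endpoint_name` is a hypothetical helper, not a library function.

```python
import re
from typing import Dict

# Matches names like projects/123/locations/us-central1/endpoints/456.
_ENDPOINT_NAME = re.compile(
    r"^projects/(?P<project>[^/]+)"
    r"/locations/(?P<location>[^/]+)"
    r"/endpoints/(?P<endpoint_id>[^/]+)$"
)


def parse_endpoint_name(name: str) -> Dict[str, str]:
    """Split a full endpoint resource name into project, location, and endpoint ID."""
    match = _ENDPOINT_NAME.match(name)
    if match is None:
        raise ValueError(f"not an endpoint resource name: {name!r}")
    return match.groupdict()
```

For example, `parse_endpoint_name("projects/123/locations/us-central1/endpoints/456")["endpoint_id"]` is `"456"`.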
To make an online inference, submit one or more test items to a model for
analysis, and the model returns results that are based on your model's
objective. Use the Google Cloud console or the Vertex AI API to request an
online inference.
Google Cloud console
In the Google Cloud console, in the Vertex AI section, go to
the Models page.
From the list of models, click the name of the model to request inferences
from.
Select the Deploy & test tab.
Under the Test your model section, add test items to request an
inference. The baseline inference data is filled in
for you, or you can enter your own inference data and click Predict.
After the inference is complete, Vertex AI returns the results in
the console.
API: Classification
gcloud
Create a file named request.json with the following contents:
{
"instances": [
{
PREDICTION_DATA_ROW
}
]
}
Replace the following:
PREDICTION_DATA_ROW: A JSON object with keys as the feature names and values as the
corresponding feature values. For example, for a dataset with a number, an array
of strings, and a category, the row contains a numeric value, a JSON array of strings,
and a string for the category.
A value must be provided for every feature included in training. The format of the data used for
prediction must match the format used for training. Refer to
Data format for predictions
for details.
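Because every feature used in training must have a value, it can help to check rows before writing `request.json`. The following minimal sketch does that and returns the file contents in the `{"instances": [...]}` shape shown above. `build_request_body` is a hypothetical helper. The feature names in the example are illustrative and not taken from any real dataset.

```python
import json
from typing import Any, Dict, Iterable, List


def build_request_body(rows: List[Dict[str, Any]], training_features: Iterable[str]) -> str:
    """Return request.json text for rows, rejecting rows that miss a training feature."""
    required = set(training_features)
    for index, row in enumerate(rows):
        missing = required.difference(row)
        if missing:
            raise ValueError(f"instance {index} is missing features: {sorted(missing)}")
    return json.dumps({"instances": rows}, indent=2)
```

For example, `build_request_body([{"length": 3.6, "material": "cotton"}], ["length", "material"])` produces a body you can save as `request.json`. Omitting `"material"` raises `ValueError`.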
REST
You use the endpoints.predict method to request an online inference.
Before using any of the request data, make the following replacements:
LOCATION_ID: Region where Endpoint is located. For example, us-central1.
PROJECT_ID: Your Google Cloud project ID.
ENDPOINT_ID: The ID for the endpoint.
PREDICTION_DATA_ROW: A JSON object with keys as the feature names and values as the
corresponding feature values, in the same format described for request.json above.
A value must be provided for every feature included in training.
DEPLOYED_MODEL_ID: Returned by the predict method; not a value you replace. The ID of the
deployed model used to generate the inference.
HTTP method and URL:
POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:predict
Request JSON body:
{
"instances": [
{
PREDICTION_DATA_ROW
}
]
}
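The following minimal sketch shows the same REST call from Python using only the standard library. It builds the POST URL from LOCATION_ID, PROJECT_ID, and ENDPOINT_ID and prepares, but does not send, the request. It assumes you obtain an OAuth access token separately, for example with `gcloud auth print-access-token`. `predict_url` and `make_predict_request` are hypothetical helpers.

```python
import json
import urllib.request
from typing import Any, Dict


def predict_url(project_id: str, location_id: str, endpoint_id: str) -> str:
    """Build the endpoints.predict URL for a regional Vertex AI endpoint."""
    return (
        f"https://{location_id}-aiplatform.googleapis.com/v1/"
        f"projects/{project_id}/locations/{location_id}/endpoints/{endpoint_id}:predict"
    )


def make_predict_request(
    url: str, body: Dict[str, Any], access_token: str
) -> urllib.request.Request:
    """Prepare (but do not send) an authenticated POST with a JSON body."""
    return urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json; charset=utf-8",
        },
        method="POST",
    )
```

Passing the prepared request to `urllib.request.urlopen` sends it. The response body is JSON that contains `predictions` and `deployedModelId`.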
To send your request, choose one of these options:
curl
Save the request body in a file named request.json,
and execute the following command:
import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.cloud.aiplatform.v1.schema.predict.prediction.TabularClassificationPredictionResult;
import com.google.protobuf.ListValue;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.List;

public class PredictTabularClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    // A JSON array of instances. Use straight double quotes so the string parses as JSON.
    String instance = "[{ \"feature_column_a\": \"value\", \"feature_column_b\": \"value\"}]";
    String endpointId = "YOUR_ENDPOINT_ID";
    predictTabularClassification(instance, project, endpointId);
  }

  static void predictTabularClassification(String instance, String project, String endpointId)
      throws IOException {
    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings)) {
      String location = "us-central1";
      EndpointName endpointName = EndpointName.of(project, location, endpointId);

      // Parse the JSON array into one protobuf Value per instance.
      ListValue.Builder listValue = ListValue.newBuilder();
      JsonFormat.parser().merge(instance, listValue);
      List<Value> instanceList = listValue.getValuesList();

      // No prediction parameters are needed here, so send an empty struct.
      Value parameters = Value.newBuilder().setStructValue(Struct.getDefaultInstance()).build();

      PredictResponse predictResponse =
          predictionServiceClient.predict(endpointName, instanceList, parameters);
      System.out.println("Predict Tabular Classification Response");
      System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());

      System.out.println("Predictions");
      for (Value prediction : predictResponse.getPredictionsList()) {
        TabularClassificationPredictionResult.Builder resultBuilder =
            TabularClassificationPredictionResult.newBuilder();
        TabularClassificationPredictionResult result =
            (TabularClassificationPredictionResult)
                ValueConverter.fromValue(resultBuilder, prediction);

        // Each class label is paired with the score at the same index.
        for (int i = 0; i < result.getClassesCount(); i++) {
          System.out.printf("\tClass: %s\n", result.getClasses(i));
          System.out.printf("\tScore: %f\n", result.getScores(i));
        }
      }
    }
  }
}
/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
// const endpointId = 'YOUR_ENDPOINT_ID';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {prediction} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.predict;

// Imports the Google Cloud Prediction service client
const {PredictionServiceClient} = aiplatform.v1;

// Import the helper module for converting arbitrary protobuf.Value objects.
const {helpers} = aiplatform;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const predictionServiceClient = new PredictionServiceClient(clientOptions);

async function predictTablesClassification() {
  // Configure the endpoint resource
  const endpoint = `projects/${project}/locations/${location}/endpoints/${endpointId}`;
  const parameters = helpers.toValue({});

  const instance = helpers.toValue({
    petal_length: '1.4',
    petal_width: '1.3',
    sepal_length: '5.1',
    sepal_width: '2.8',
  });

  const instances = [instance];
  const request = {
    endpoint,
    instances,
    parameters,
  };

  // Predict request
  const [response] = await predictionServiceClient.predict(request);

  console.log('Predict tabular classification response');
  console.log(`\tDeployed model id : ${response.deployedModelId}\n`);
  const predictions = response.predictions;
  console.log('Predictions :');
  for (const predictionResultVal of predictions) {
    const predictionResultObj =
      prediction.TabularClassificationPredictionResult.fromValue(
        predictionResultVal
      );
    for (const [i, class_] of predictionResultObj.classes.entries()) {
      console.log(`\tClass: ${class_}`);
      console.log(`\tScore: ${predictionResultObj.scores[i]}\n\n`);
    }
  }
}
predictTablesClassification();
from typing import Dict, List

from google.cloud import aiplatform


def predict_tabular_classification_sample(
    project: str,
    location: str,
    endpoint_name: str,
    instances: List[Dict],
):
    """
    Args
        project: Your project ID or project number.
        location: Region where Endpoint is located. For example, 'us-central1'.
        endpoint_name: A fully qualified endpoint name or endpoint ID. Example: "projects/123/locations/us-central1/endpoints/456" or
               "456" when project and location are initialized or passed.
        instances: A list of one or more instances (examples) to return a prediction for.
    """
    aiplatform.init(project=project, location=location)

    endpoint = aiplatform.Endpoint(endpoint_name)

    response = endpoint.predict(instances=instances)

    for prediction_ in response.predictions:
        print(prediction_)
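Each classification prediction pairs a list of `classes` with a list of `scores` at matching indexes. The Java and Node.js samples read these same two fields. The following minimal sketch picks the highest-scoring class from one prediction. It assumes the prediction arrives as a dict with those two keys. `top_class` is a hypothetical helper.

```python
from typing import Any, Dict, Tuple


def top_class(prediction: Dict[str, Any]) -> Tuple[str, float]:
    """Return the (class, score) pair with the highest score."""
    classes = list(prediction["classes"])
    scores = [float(score) for score in prediction["scores"]]
    if not classes or len(classes) != len(scores):
        raise ValueError("expected non-empty, equal-length 'classes' and 'scores'")
    # Index of the maximum score; ties resolve to the first occurrence.
    best = max(range(len(scores)), key=scores.__getitem__)
    return classes[best], scores[best]
```

For example, `top_class({"classes": ["no", "yes"], "scores": [0.2, 0.8]})` returns `("yes", 0.8)`.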
API: Regression
gcloud
Create a file named request.json with the following contents:
{
"instances": [
{
PREDICTION_DATA_ROW
}
]
}
Replace the following:
PREDICTION_DATA_ROW: A JSON object with keys as the feature names and values as the
corresponding feature values. For example, for a dataset with two numeric features
and a categorical feature, the row of data might look like the following example request:
"age":3.6,
"sq_ft":5392,
"code": "90331"
A value must be provided for every feature included in training. The format of the data used for
prediction must match the format used for training. Refer to
Data format for predictions
for details.
LOCATION_ID: The region where you are using Vertex AI.
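In the example row above, `code` is sent as a quoted string even though it looks numeric. The following minimal sketch checks a row against a hypothetical `{feature: "numeric" | "categorical"}` schema that you would write from your own training data. Two rules here are assumptions, not documented validation rules. First, categorical values are expected to be strings, as in the example. Second, numeric values may be numbers or numeric strings, since the Node.js classification sample sends `'1.4'`. `check_row_types` is a hypothetical helper.

```python
from typing import Any, Dict, List


def _is_number(value: Any) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly.
    if isinstance(value, bool):
        return False
    if isinstance(value, (int, float)):
        return True
    if isinstance(value, str):
        try:
            float(value)
            return True
        except ValueError:
            return False
    return False


def check_row_types(row: Dict[str, Any], schema: Dict[str, str]) -> List[str]:
    """List problems found when comparing a row with a hypothetical feature schema."""
    problems = []
    for name, kind in schema.items():
        if name not in row:
            problems.append(f"{name}: missing")
        elif kind == "numeric" and not _is_number(row[name]):
            problems.append(f"{name}: expected a number")
        elif kind == "categorical" and not isinstance(row[name], str):
            problems.append(f"{name}: expected a string")
    return problems
```

With the schema `{"age": "numeric", "sq_ft": "numeric", "code": "categorical"}`, the example row passes. Sending `"code": 90331` as a bare number is reported as a problem.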
REST
You use the
endpoints.predict
method to request an online inference.
Before using any of the request data,
make the following replacements:
LOCATION_ID: Region where Endpoint is located. For example, us-central1.
PROJECT_ID: Your Google Cloud project ID.
ENDPOINT_ID: The ID for the endpoint.
PREDICTION_DATA_ROW: A JSON object with keys as the feature names and values as the
corresponding feature values, in the same format described for request.json above.
A value must be provided for every feature included in training.
DEPLOYED_MODEL_ID: Returned by the predict method; not a value you replace. The ID of the
deployed model used to generate the inference.
HTTP method and URL:
POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:predict
Request JSON body:
{
"instances": [
{
PREDICTION_DATA_ROW
}
]
}
To send your request, choose one of these options:
curl
Save the request body in a file named request.json,
and execute the following command:
import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.cloud.aiplatform.v1.schema.predict.prediction.TabularRegressionPredictionResult;
import com.google.protobuf.ListValue;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.List;

public class PredictTabularRegressionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    // A JSON array of instances. Use straight double quotes so the string parses as JSON.
    String instance = "[{ \"feature_column_a\": \"value\", \"feature_column_b\": \"value\"}]";
    String endpointId = "YOUR_ENDPOINT_ID";
    predictTabularRegression(instance, project, endpointId);
  }

  static void predictTabularRegression(String instance, String project, String endpointId)
      throws IOException {
    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings)) {
      String location = "us-central1";
      EndpointName endpointName = EndpointName.of(project, location, endpointId);

      // Parse the JSON array into one protobuf Value per instance.
      ListValue.Builder listValue = ListValue.newBuilder();
      JsonFormat.parser().merge(instance, listValue);
      List<Value> instanceList = listValue.getValuesList();

      // No prediction parameters are needed here, so send an empty struct.
      Value parameters = Value.newBuilder().setStructValue(Struct.getDefaultInstance()).build();

      PredictResponse predictResponse =
          predictionServiceClient.predict(endpointName, instanceList, parameters);
      System.out.println("Predict Tabular Regression Response");
      System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());

      System.out.println("Predictions");
      for (Value prediction : predictResponse.getPredictionsList()) {
        TabularRegressionPredictionResult.Builder resultBuilder =
            TabularRegressionPredictionResult.newBuilder();
        TabularRegressionPredictionResult result =
            (TabularRegressionPredictionResult) ValueConverter.fromValue(resultBuilder, prediction);
        System.out.printf("\tUpper bound: %f\n", result.getUpperBound());
        System.out.printf("\tLower bound: %f\n", result.getLowerBound());
        System.out.printf("\tValue: %f\n", result.getValue());
      }
    }
  }
}
/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
// const endpointId = 'YOUR_ENDPOINT_ID';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {prediction} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.predict;

// Imports the Google Cloud Prediction service client
const {PredictionServiceClient} = aiplatform.v1;

// Import the helper module for converting arbitrary protobuf.Value objects.
const {helpers} = aiplatform;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const predictionServiceClient = new PredictionServiceClient(clientOptions);

async function predictTablesRegression() {
  // Configure the endpoint resource
  const endpoint = `projects/${project}/locations/${location}/endpoints/${endpointId}`;
  const parameters = helpers.toValue({});

  // Example instance; replace the keys and values with your own training features.
  const instance = helpers.toValue({
    BOOLEAN_2unique_NULLABLE: false,
    DATETIME_1unique_NULLABLE: '2019-01-01 00:00:00',
    DATE_1unique_NULLABLE: '2019-01-01',
    FLOAT_5000unique_NULLABLE: 1611,
    FLOAT_5000unique_REPEATED: [2320, 1192],
    INTEGER_5000unique_NULLABLE: '8',
    NUMERIC_5000unique_NULLABLE: 16,
    STRING_5000unique_NULLABLE: 'str-2',
    STRUCT_NULLABLE: {
      BOOLEAN_2unique_NULLABLE: false,
      DATE_1unique_NULLABLE: '2019-01-01',
      DATETIME_1unique_NULLABLE: '2019-01-01 00:00:00',
      FLOAT_5000unique_NULLABLE: 1308,
      FLOAT_5000unique_REPEATED: [2323, 1178],
      FLOAT_5000unique_REQUIRED: 3089,
      INTEGER_5000unique_NULLABLE: '1777',
      NUMERIC_5000unique_NULLABLE: 3323,
      TIME_1unique_NULLABLE: '23:59:59.999999',
      STRING_5000unique_NULLABLE: 'str-49',
      TIMESTAMP_1unique_NULLABLE: '1546387199999999',
    },
    TIMESTAMP_1unique_NULLABLE: '1546387199999999',
    TIME_1unique_NULLABLE: '23:59:59.999999',
  });

  const instances = [instance];
  const request = {
    endpoint,
    instances,
    parameters,
  };

  // Predict request
  const [response] = await predictionServiceClient.predict(request);

  console.log('Predict tabular regression response');
  console.log(`\tDeployed model id : ${response.deployedModelId}`);
  const predictions = response.predictions;
  console.log('\tPredictions :');
  for (const predictionResultVal of predictions) {
    const predictionResultObj =
      prediction.TabularRegressionPredictionResult.fromValue(
        predictionResultVal
      );
    console.log(`\tUpper bound: ${predictionResultObj.upper_bound}`);
    console.log(`\tLower bound: ${predictionResultObj.lower_bound}`);
    console.log(`\tValue: ${predictionResultObj.value}`);
  }
}
predictTablesRegression();
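Each regression prediction carries a point `value` and, when available, an interval given by `lower_bound` and `upper_bound`. These are the same fields the Java and Node.js samples print. The following minimal sketch formats one prediction and reports the interval width. It assumes the prediction is a dict keyed by those names, as in the Python sample's `response.predictions`. `summarize_regression` is a hypothetical helper.

```python
from typing import Dict


def summarize_regression(prediction: Dict[str, float]) -> str:
    """Format a regression prediction and, if present, its prediction interval."""
    value = float(prediction["value"])
    lower = prediction.get("lower_bound")
    upper = prediction.get("upper_bound")
    if lower is None or upper is None:
        return f"value={value:.3f}"
    low, high = float(lower), float(upper)
    return f"value={value:.3f}, interval=[{low:.3f}, {high:.3f}] (width {high - low:.3f})"
```

For example, `summarize_regression({"value": 10.0, "lower_bound": 8.0, "upper_bound": 13.0})` returns `"value=10.000, interval=[8.000, 13.000] (width 5.000)"`.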
TabNet provides inherent model interpretability by showing you which features
it used to make its decision. The algorithm uses attention, which learns to
selectively enhance the influence of some features while diminishing the
influence of others through a weighted average. For a particular decision,
TabNet decides in a stepwise fashion how much importance to place on each
feature, and then combines the steps to create a final prediction. The attention
is multiplicative: larger values indicate that the feature played a larger role
in the prediction, and a value of zero means that the feature played no role in
that decision. Because TabNet uses multiple decision steps, the attention placed
on the features across all of the steps is linearly combined after appropriate
scaling. This linear combination across all of TabNet's decision steps is the
total feature importance that TabNet provides.
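The following minimal sketch shows how per-step attention can be linearly combined into one importance per feature. This page says only "after appropriate scaling", so the sketch assumes the scheme described in the original TabNet paper. In that scheme, each step's attention mask is multiplied by a per-step weight that reflects that step's contribution to the decision, the weighted masks are summed, and the result is normalized to sum to 1. The step weights are inputs here; Vertex AI does not expose them on this page. `aggregate_feature_importance` is a hypothetical helper.

```python
from typing import List, Sequence


def aggregate_feature_importance(
    step_masks: Sequence[Sequence[float]], step_weights: Sequence[float]
) -> List[float]:
    """Combine per-step attention into one normalized importance per feature.

    step_masks[i][j] is the non-negative attention on feature j at decision step i.
    step_weights[i] is the scaling applied to step i before the steps are summed.
    """
    if not step_masks or len(step_masks) != len(step_weights):
        raise ValueError("need one weight per decision step")
    num_features = len(step_masks[0])
    totals = [0.0] * num_features
    for mask, weight in zip(step_masks, step_weights):
        if len(mask) != num_features:
            raise ValueError("every step must cover the same features")
        for j, attention in enumerate(mask):
            totals[j] += weight * attention  # linear combination across steps
    norm = sum(totals)
    # Normalize so importances sum to 1; a feature with zero attention stays at 0.
    return [total / norm for total in totals] if norm > 0 else totals
```

For example, masks `[[0.5, 0.5, 0.0], [0.0, 1.0, 0.0]]` with step weights `[2.0, 1.0]` give importances of about `[0.333, 0.667, 0.0]`. The third feature received no attention in any step, so its importance is zero.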
Example output for inferences
The return payload for an online inference with feature importance from a
regression model looks similar to the following example.
Last updated 2026-06-09 UTC.