

An Estimator for K-Means clustering.

Inherits From: Estimator


import numpy as np
import tensorflow as tf

num_points = 100
dimensions = 2
points = np.random.uniform(0, 1000, [num_points, dimensions])

def input_fn():
  return tf.compat.v1.train.limit_epochs(
      tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1)

num_clusters = 5
kmeans = tf.compat.v1.estimator.experimental.KMeans(
    num_clusters=num_clusters, use_mini_batch=False)

# Train: input_fn yields one epoch, so each train() call makes one pass over
# the points (a full batch here, since use_mini_batch=False).
num_iterations = 10
previous_centers = None
for _ in range(num_iterations):
  kmeans.train(input_fn)
  cluster_centers = kmeans.cluster_centers()
  if previous_centers is not None:
    print('delta:', cluster_centers - previous_centers)
  previous_centers = cluster_centers
  print('score:', kmeans.score(input_fn))
print('cluster centers:', cluster_centers)

# Map the input points to their clusters.
cluster_indices = list(kmeans.predict_cluster_index(input_fn))
for i, point in enumerate(points):
  cluster_index = cluster_indices[i]
  center = cluster_centers[cluster_index]
  print('point:', point, 'is in cluster', cluster_index, 'centered at', center)
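To make the training loop above more concrete, here is a NumPy sketch of one full-batch (Lloyd) k-means iteration and a squared-Euclidean score. Under the default SQUARED_EUCLIDEAN_DISTANCE metric with use_mini_batch=False, this is roughly what one train()/score() round computes. It is an illustration of the technique under those assumptions, not the Estimator's implementation; assign_clusters, lloyd_step, and kmeans_score are hypothetical helper names.

```python
import numpy as np

def assign_clusters(points, centers):
    """Return (nearest-center index, squared distance to it) for each point."""
    # Pairwise squared Euclidean distances, shape [num_points, num_centers].
    d2 = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1), d2.min(axis=1)

def lloyd_step(points, centers):
    """One full-batch iteration: assign points, then move each center to its cluster mean."""
    labels, _ = assign_clusters(points, centers)
    new_centers = centers.copy()
    for k in range(len(centers)):
        members = points[labels == k]
        if len(members) > 0:  # an empty cluster keeps its previous center
            new_centers[k] = members.mean(axis=0)
    return new_centers

def kmeans_score(points, centers):
    """Sum of squared distances from each point to its nearest center."""
    _, d2 = assign_clusters(points, centers)
    return d2.sum()

rng = np.random.default_rng(0)
points = rng.uniform(0, 1000, [100, 2])
centers = points[rng.choice(len(points), size=5, replace=False)]
for _ in range(10):
    centers = lloyd_step(points, centers)
    print('score:', kmeans_score(points, centers))
```

Because both the assignment step and the mean update can only lower the total squared distance, the printed score never increases from one iteration to the next.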

The SavedModel saved by the export_saved_model method does not include the cluster centers. However, the cluster centers may be retrieved from the latest checkpoint saved during training. Specifically,

kmeans.cluster_centers()

is equivalent to

tf.train.load_variable(
    kmeans.model_dir, KMeansClustering.CLUSTER_CENTERS_VAR_NAME)

num_clusters: An integer tensor specifying the number of clusters. This argument is ignored if initial_clusters is a tensor or numpy array.
model_dir: The directory in which to save model results and log files.
initial_clusters: Specifies how the initial cluster centers are chosen. One of the following:
  • a tensor or numpy array with the initial cluster centers.
  • a callable f(inputs, k) that selects and returns up to k centers from an input batch. f is free to return any number of centers from 0 to k. It will be invoked on successive input batches as necessary until all num_clusters centers are chosen.
  • KMeansClustering.RANDOM_INIT: Choose centers randomly from an input batch. If the batch size is less than num_clusters, the entire batch is chosen as initial cluster centers and the remaining centers are chosen from successive input batches.
  • KMeansClustering.KMEANS_PLUS_PLUS_INIT: Use kmeans++ to choose centers from the first input batch. If the batch size is less than num_clusters, a TensorFlow runtime error occurs.
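The batch-by-batch behavior described for RANDOM_INIT, and likewise the "invoke on successive batches until all centers are chosen" contract for a callable initializer, can be sketched in NumPy as follows. This illustrates the described semantics only and is not TensorFlow's code; random_init_from_batches is a hypothetical helper, and batches is assumed to be an iterable of [batch_size, dimensions] arrays.

```python
import numpy as np

def random_init_from_batches(batches, num_clusters, seed=None):
    """Choose initial centers from a stream of input batches (RANDOM_INIT-style sketch).

    A batch no larger than the number of centers still needed is taken whole;
    otherwise the remaining centers are sampled from it without replacement.
    """
    rng = np.random.default_rng(seed)
    chosen = []
    remaining = num_clusters
    for batch in batches:
        if remaining == 0:
            break
        if len(batch) <= remaining:
            picked = batch  # the entire batch becomes initial centers
        else:
            idx = rng.choice(len(batch), size=remaining, replace=False)
            picked = batch[idx]
        chosen.append(picked)
        remaining -= len(picked)
    if remaining > 0:
        raise ValueError('input ran out before num_clusters centers were chosen')
    return np.concatenate(chosen, axis=0)

batches = [np.arange(6.0).reshape(3, 2), np.arange(6.0, 20.0).reshape(7, 2)]
centers = random_init_from_batches(batches, num_clusters=5, seed=0)
print(centers.shape)  # (5, 2): all 3 rows of the first batch, 2 sampled from the second
```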
distance_metric: The distance metric used for clustering. One of:
  • KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE: The Euclidean distance between vectors u and v is \(||u - v||_2\), the square root of the sum of the squares of the element-wise differences; this metric uses its square, \(||u - v||_2^2\).
  • KMeansClustering.COSINE_DISTANCE: The cosine distance between vectors u and v is defined as \(1 - (u \cdot v) / (||u||_2 ||v||_2)\).
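The two metrics can disagree about which center is nearest, since cosine distance ignores vector length. The NumPy sketch below implements both formulas and a nearest-center lookup; the function names are hypothetical, and cosine_distance assumes neither vector is all zeros.

```python
import numpy as np

def squared_euclidean(u, v):
    """||u - v||_2^2: the sum of squared element-wise differences."""
    diff = np.asarray(u, dtype=float) - np.asarray(v, dtype=float)
    return float(np.dot(diff, diff))

def cosine_distance(u, v):
    """1 - (u . v) / (||u||_2 ||v||_2); assumes u and v are nonzero."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    return 1.0 - float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

def nearest_center(point, centers, metric):
    """Index of the center closest to point under the given metric."""
    return int(np.argmin([metric(point, c) for c in centers]))

centers = np.array([[90.0, 0.0], [1.0, 1.0]])
point = np.array([100.0, 100.0])
print(nearest_center(point, centers, squared_euclidean))  # 0: closer in space
print(nearest_center(point, centers, cosine_distance))    # 1: same direction
```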
seed: Python integer. Seed for the PRNG used to initialize centers.
use_mini_batch: A boolean specifying whether to use the mini-batch k-means algorithm. If False, each training step is assumed to see the full batch, as in the example above.
mini_batch_steps_per_iteration: The number of steps after which the updated cluster centers are synced back to a master copy. Used only if use_mini_batch=True.
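The general idea behind use_mini_batch=True can be illustrated with the common mini-batch k-means update, in which each center moves toward its assigned points with a per-center learning rate of 1/count. This single-process NumPy sketch is an assumption about the general technique, not TensorFlow's implementation: it ignores the syncing to a master copy controlled by mini_batch_steps_per_iteration, and minibatch_kmeans_step is a hypothetical name.

```python
import numpy as np

def minibatch_kmeans_step(batch, centers, counts):
    """Update float centers in place from one mini-batch; counts tracks points seen per center."""
    # Assign every point in the batch using the centers as they were before this step.
    d2 = ((batch[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    labels = d2.argmin(axis=1)
    for x, k in zip(batch, labels):
        counts[k] += 1
        eta = 1.0 / counts[k]  # learning rate shrinks as a center absorbs more points
        centers[k] = (1.0 - eta) * centers[k] + eta * x
    return centers, counts

rng = np.random.default_rng(0)
data = rng.uniform(0, 1000, [1000, 2])
centers = data[rng.choice(len(data), size=5, replace=False)].copy()
counts = np.zeros(len(centers))
for step in range(20):
    batch = data[rng.choice(len(data), size=50, replace=False)]
    centers, counts = minibatch_kmeans_step(batch, centers, counts)
```

With eta = 1/count, each center becomes the running mean of the points assigned to it so far, so a single step over a small batch lands each center exactly on the mean of its assigned points.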
kmeans_plus_plus_num_retries: For each point that is sampled during kmeans++ initialization, this parameter specifies the number of additional points to draw from the current distribution before selecting the best. If a negative value is specified, a heuristic is used to sample O(log(num_to_sample)) additional points. Used only if initial_clusters=KMeansClustering.KMEANS_PLUS_PLUS_INIT.
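The retry behavior can be sketched as greedy k-means++ seeding in NumPy: for each new center, draw 1 + num_retries candidates with probability proportional to their squared distance from the nearest existing center, and keep the candidate that most reduces the total squared distance. This is an illustration, not TensorFlow's kernel. kmeans_plus_plus is a hypothetical name, and the negative-value fallback of int(log(k)) extra candidates is an assumed stand-in for the documented O(log(num_to_sample)) heuristic. As with KMEANS_PLUS_PLUS_INIT, it fails when there are fewer points than centers.

```python
import numpy as np

def kmeans_plus_plus(points, k, num_retries=-1, seed=None):
    """Greedy k-means++ seeding sketch over a single batch of points."""
    rng = np.random.default_rng(seed)
    n = len(points)
    if n < k:
        raise ValueError('need at least k points to choose k centers')
    if num_retries < 0:
        num_retries = int(np.log(k))  # assumed heuristic, not TensorFlow's exact rule
    centers = [points[rng.integers(n)]]  # first center: uniform at random
    closest = ((points - centers[0]) ** 2).sum(axis=1)  # squared distance to nearest center
    for _ in range(1, k):
        total = closest.sum()
        probs = closest / total if total > 0 else np.full(n, 1.0 / n)
        candidates = rng.choice(n, size=1 + num_retries, p=probs)
        best, best_closest = None, None
        for c in candidates:
            new_closest = np.minimum(closest, ((points - points[c]) ** 2).sum(axis=1))
            # Keep the candidate that leaves the smallest total squared distance.
            if best_closest is None or new_closest.sum() < best_closest.sum():
                best, best_closest = c, new_closest
        centers.append(points[best])
        closest = best_closest
    return np.array(centers)

pts = np.array([[0.0, 0.0], [0.1, 0.0], [100.0, 0.0], [100.1, 0.0], [0.0, 100.0]])
print(kmeans_plus_plus(pts, 3, num_retries=2, seed=1))
```

Drawing extra candidates per step trades a little more computation for seeds that are less likely to land close to an existing center.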