Skip to content

Commit

Permalink
[functions][state] Python state support (apache#2714)
Browse files Browse the repository at this point in the history
*Motivation*

Add state support in python functions

*Changes*

- Bump bookkeeper version, so the table service has the changes to support python functions
- Add state to python function
  • Loading branch information
sijie authored Feb 11, 2019
1 parent bfcc7c7 commit 21b8c9d
Show file tree
Hide file tree
Showing 13 changed files with 522 additions and 14 deletions.
20 changes: 20 additions & 0 deletions pulsar-client-cpp/python/pulsar/functions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,23 @@ def get_output_serde_class_name(self):
def ack(self, msgid, topic):
"""ack this message id"""
pass

@abstractmethod
def incr_counter(self, key, amount):
"""incr the counter of a given key in the managed state"""
pass

@abstractmethod
def get_counter(self, key):
"""get the counter of a given key in the managed state"""
pass

@abstractmethod
def put_state(self, key, value):
"""update the value of a given key in the managed state"""
pass

@abstractmethod
def get_state(self, key):
"""get the value of a given key in the managed state"""
pass
6 changes: 4 additions & 2 deletions pulsar-client-cpp/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@ def build_extension(self, ext):


dependencies = [
'grpcio', 'protobuf',
'six',
'fastavro',
'grpcio',
'protobuf',
'six',

# functions dependencies
"apache-bookkeeper-client",
"prometheus_client",
"ratelimit"
]
Expand Down
15 changes: 14 additions & 1 deletion pulsar-functions/instance/src/main/python/contextimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ class ContextImpl(pulsar.Context):
# add label to indicate user metric
user_metrics_label_names = Stats.metrics_label_names + ["metric"]

def __init__(self, instance_config, logger, pulsar_client, user_code, consumers, secrets_provider, metrics_labels):
def __init__(self, instance_config, logger, pulsar_client, user_code, consumers, secrets_provider, metrics_labels, state_context):
self.instance_config = instance_config
self.log = logger
self.pulsar_client = pulsar_client
self.user_code_dir = os.path.dirname(user_code)
self.consumers = consumers
self.secrets_provider = secrets_provider
self.state_context = state_context
self.publish_producers = {}
self.publish_serializers = {}
self.message = None
Expand Down Expand Up @@ -186,3 +187,15 @@ def get_metrics(self):
metrics_map["%s%s_count" % (Stats.USER_METRIC_PREFIX, metric_name)] = user_metric._count.get()

return metrics_map

def incr_counter(self, key, amount):
return self.state_context.incr(key, amount)

def get_counter(self, key):
return self.state_context.get_amount(key)

def put_state(self, key, value):
return self.state_context.put(key, value)

def get_state(self, key):
return self.state_context.get_value(key)
30 changes: 27 additions & 3 deletions pulsar-functions/instance/src/main/python/python_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
import util
import InstanceCommunication_pb2

# state dependencies
import state_context

from functools import partial
from collections import namedtuple
from function_stats import Stats
Expand Down Expand Up @@ -67,15 +70,26 @@ def base64ify(bytes_or_str):
return output_bytes

class PythonInstance(object):
def __init__(self, instance_id, function_id, function_version, function_details, max_buffered_tuples,
expected_healthcheck_interval, user_code, pulsar_client, secrets_provider, cluster_name):
def __init__(self,
instance_id,
function_id,
function_version,
function_details,
max_buffered_tuples,
expected_healthcheck_interval,
user_code,
pulsar_client,
secrets_provider,
cluster_name,
state_storage_serviceurl):
self.instance_config = InstanceConfig(instance_id, function_id, function_version, function_details, max_buffered_tuples)
self.user_code = user_code
self.queue = queue.Queue(max_buffered_tuples)
self.log_topic_handler = None
if function_details.logTopic is not None and function_details.logTopic != "":
self.log_topic_handler = log.LogTopicHandler(str(function_details.logTopic), pulsar_client)
self.pulsar_client = pulsar_client
self.state_storage_serviceurl = state_storage_serviceurl
self.input_serdes = {}
self.consumers = {}
self.output_serde = None
Expand All @@ -91,6 +105,7 @@ def __init__(self, instance_id, function_id, function_version, function_details,
self.timeout_ms = function_details.source.timeoutMs if function_details.source.timeoutMs > 0 else None
self.expected_healthcheck_interval = expected_healthcheck_interval
self.secrets_provider = secrets_provider
self.state_context = state_context.NullStateContext()
self.metrics_labels = [function_details.tenant,
"%s/%s" % (function_details.tenant, function_details.namespace),
function_details.name,
Expand All @@ -111,6 +126,9 @@ def process_spawner_health_check_timer(self):
sys.exit(1)

def run(self):
# Setup state
self.state_context = self.setup_state()

# Setup consumers and input deserializers
mode = pulsar._pulsar.ConsumerType.Shared
if self.instance_config.function_details.source.subscriptionType == Function_pb2.SubscriptionType.Value("FAILOVER"):
Expand Down Expand Up @@ -176,7 +194,7 @@ def run(self):

self.contextimpl = contextimpl.ContextImpl(self.instance_config, Log, self.pulsar_client,
self.user_code, self.consumers,
self.secrets_provider, self.metrics_labels)
self.secrets_provider, self.metrics_labels, self.state_context)
# Now launch a thread that does execution
self.execution_thread = threading.Thread(target=self.actual_execution)
self.execution_thread.start()
Expand Down Expand Up @@ -287,6 +305,12 @@ def setup_producer(self):
self.instance_config.instance_id)
)

def setup_state(self):
table_ns = "%s_%s" % (str(self.instance_config.function_details.tenant),
str(self.instance_config.function_details.namespace))
table_name = str(self.instance_config.function_details.name)
return state_context.create_state_context(self.state_storage_serviceurl, table_ns, table_name)

def message_listener(self, serde, consumer, message):
# increment number of received records from source
self.stats.incr_total_received()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import prometheus_client_fix

from google.protobuf import json_format
from bookkeeper.kv.client import Client

to_run = True
Log = log.Log
Expand Down Expand Up @@ -84,6 +85,7 @@ def main():
parser.add_argument('--install_usercode_dependencies', required=False, help='For packaged python like wheel files, do we need to install all dependencies', type=bool)
parser.add_argument('--dependency_repository', required=False, help='For packaged python like wheel files, which repository to pull the dependencies from')
parser.add_argument('--extra_dependency_repository', required=False, help='For packaged python like wheel files, any extra repository to pull the dependencies from')
parser.add_argument('--state_storage_serviceurl', required=False, help='Managed State Storage Service Url')
parser.add_argument('--cluster_name', required=True, help='The name of the cluster this instance is running on')

args = parser.parse_args()
Expand Down Expand Up @@ -158,6 +160,10 @@ def main():
tls_trust_cert_path = args.tls_trust_cert_path
pulsar_client = pulsar.Client(args.pulsar_serviceurl, authentication, 30, 1, 1, 50000, None, use_tls, tls_trust_cert_path, tls_allow_insecure_connection)

state_storage_serviceurl = None
if args.state_storage_serviceurl is not None:
state_storage_serviceurl = str(args.state_storage_serviceurl)

secrets_provider = None
if args.secrets_provider is not None:
secrets_provider = util.import_class(os.path.dirname(inspect.getfile(inspect.currentframe())), str(args.secrets_provider))
Expand All @@ -178,7 +184,11 @@ def main():
str(args.function_version), function_details,
int(args.max_buffered_tuples),
int(args.expected_healthcheck_interval),
str(args.py), pulsar_client, secrets_provider, args.cluster_name)
str(args.py),
pulsar_client,
secrets_provider,
args.cluster_name,
state_storage_serviceurl)
pyinstance.run()
server_instance = server.serve(args.port, pyinstance)

Expand Down
156 changes: 156 additions & 0 deletions pulsar-functions/instance/src/main/python/state_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

# -*- encoding: utf-8 -*-

"""state_context.py: state context for accessing managed state
"""
from abc import abstractmethod
from bookkeeper import admin, kv
from bookkeeper.common.exceptions import NamespaceNotFoundError, StreamNotFoundError, KeyNotFoundError
from bookkeeper.proto import stream_pb2
from bookkeeper.proto.stream_pb2 import HASH
from bookkeeper.proto.stream_pb2 import TABLE
from bookkeeper.types import StorageClientSettings


def new_bk_table_conf(num_ranges):
"""Create a table configuration with the specified `num_ranges`"""
return stream_pb2.StreamConfiguration(
key_type=HASH,
min_num_ranges=num_ranges,
initial_num_ranges=num_ranges,
split_policy=stream_pb2.SplitPolicy(
type=stream_pb2.SplitPolicyType.values()[0],
fixed_range_policy=stream_pb2.FixedRangeSplitPolicy(
num_ranges=2
)
),
rolling_policy=stream_pb2.SegmentRollingPolicy(
size_policy=stream_pb2.SizeBasedSegmentRollingPolicy(
max_segment_size=128 * 1024 * 1024
)
),
retention_policy=stream_pb2.RetentionPolicy(
time_policy=stream_pb2.TimeBasedRetentionPolicy(
retention_minutes=-1
)
),
storage_type=TABLE
)


def create_state_context(state_storage_serviceurl, table_ns, table_name):
"""Create the state context based on state storage serviceurl"""
if state_storage_serviceurl is None:
return NullStateContext()
else:
return BKManagedStateContext(state_storage_serviceurl, table_ns, table_name)


class StateContext(object):
"""Interface defining operations on managed state"""

@abstractmethod
def incr(self, key, amount):
pass

@abstractmethod
def put(self, key, value):
pass

@abstractmethod
def get_value(self, key):
pass

@abstractmethod
def get_amount(self, key):
pass


class NullStateContext(StateContext):
"""A state context that does nothing"""

def incr(self, key, amount):
return

def put(self, key, value):
return

def get_value(self, key):
return None

def get_amount(self, key):
return None


class BKManagedStateContext(StateContext):
"""A state context that access bookkeeper managed state"""

def __init__(self, state_storage_serviceurl, table_ns, table_name):
client_settings = StorageClientSettings(
service_uri=state_storage_serviceurl)
admin_client = admin.client.Client(
storage_client_settings=client_settings)
# create namespace and table if needed
ns = admin_client.namespace(table_ns)
try:
ns.get(stream_name=table_name)
except NamespaceNotFoundError:
admin_client.namespaces().create(namespace=table_ns)
# TODO: make number of table ranges configurable
table_conf = new_bk_table_conf(1)
ns.create(
stream_name=table_name,
stream_config=table_conf)
except StreamNotFoundError:
# TODO: make number of table ranges configurable
table_conf = new_bk_table_conf(1)
ns.create(
stream_name=table_name,
stream_config=table_conf)
self.__client__ = kv.Client(namespace=table_ns)
self.__table__ = self.__client__.table(table_name=table_name)

def incr(self, key, amount):
return self.__table__.incr_str(key, amount)

def get_amount(self, key):
try:
kv = self.__table__.get_str(key)
if kv is not None:
return kv.number_value
else:
return None
except KeyNotFoundError:
return None

def get_value(self, key):
try:
kv = self.__table__.get_str(key)
if kv is not None:
return kv.value
else:
return None
except KeyNotFoundError:
return None

def put(self, key, value):
return self.__table__.put_str(key, value)
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_context_publish(self):
pulsar_client.create_producer = Mock(return_value=producer)
user_code=__file__
consumers = None
context_impl = ContextImpl(instance_config, logger, pulsar_client, user_code, consumers, None, None)
context_impl = ContextImpl(instance_config, logger, pulsar_client, user_code, consumers, None, None, None)

context_impl.publish("test_topic_name", "test_message")

Expand Down
34 changes: 34 additions & 0 deletions pulsar-functions/python-examples/wordcount_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#


from pulsar import Function

# The classic ExclamationFunction that appends an exclamation at the end
# of the input
class WordCountFunction(Function):
def __init__(self):
pass

def process(self, input, context):
words = input.split()
for word in words:
context.incr_counter(word, 1)
return input + "!"
Loading

0 comments on commit 21b8c9d

Please sign in to comment.