Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions app/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,11 @@ class PromptGenerationRequestSerializer(serializers.Serializer):
prompt = serializers.CharField(min_length=1, max_length=10000)


# Used for incoming streaming-generation requests. Does NOT map to a model.
class PromptStreamingRequestSerializer(serializers.Serializer):
    # Ordered list of conversation messages forwarded to the generation driver.
    # NOTE(review): element schema is not validated here — presumably dicts of
    # role/content pairs; confirm against the driver's expected message shape.
    conversation = serializers.ListField()
    # Optional system prompt. A default ensures the key is always present in
    # validated_data, so consumers can index it directly without a KeyError
    # when the client omits the field.
    system_prompt = serializers.CharField(required=False, allow_blank=True, default="")


# Used for incoming requests to copy a widget instance. Does NOT map to a model.
class WidgetInstanceCopyRequestSerializer(serializers.Serializer):
new_name = serializers.ModelField(
Expand Down
1 change: 1 addition & 0 deletions app/api/urls/api_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@
# AI generation
path("generate/qset/", generation.GenerateQsetView.as_view()),
path("generate/from_prompt/", generation.GenerateFromPromptView.as_view()),
path("generate/streaming/", generation.GenerateStreamingResponseView.as_view()),
path("lti/<slug:context_id>/instances/", LtiWidgetInstancesInCourseView.as_view()),
]
63 changes: 49 additions & 14 deletions app/api/views/generation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import logging
import re

from rest_framework.response import Response
from rest_framework.views import APIView

from api.permissions import CanCreateWidgetInstances
from api.serializers import (
QsetGenerationRequestSerializer,
PromptGenerationRequestSerializer,
PromptStreamingRequestSerializer,
QsetGenerationRequestSerializer,
)
from core.services.generator_service import GenerationUtil
from core.message_exception import MsgNoLogin, MsgInvalidInput, MsgFailure
from core.message_exception import MsgFailure, MsgInvalidInput, MsgNoLogin
from django.http import StreamingHttpResponse
from generation.core import GenerationCore
from generation.factory import GenerationDriverFactory
from rest_framework.response import Response
from rest_framework.views import APIView

logger = logging.getLogger(__name__)


class GenerateQsetView(APIView):
Expand All @@ -18,7 +23,7 @@ class GenerateQsetView(APIView):

def post(self, request):
# Check if generation is available
if not GenerationUtil.is_enabled():
if not GenerationCore.is_enabled():
raise MsgFailure(
msg="AI generation is not enabled on this instance of Materia"
)
Expand Down Expand Up @@ -52,19 +57,20 @@ def post(self, request):
if num_questions > 32:
num_questions = 32

# Generate qset
result = GenerationUtil.generate_qset(
# Get the appropriate driver and generate qset
driver = GenerationDriverFactory.get_driver()
result = driver.generate_qset(
widget=widget,
instance=widget_instance,
topic=topic,
num_questions=num_questions,
build_off_existing=build_off_existing,
instance=widget_instance,
)

# Return generated qset
return Response(
{
**result,
"qset": result,
"title": topic,
}
)
Expand All @@ -76,7 +82,7 @@ class GenerateFromPromptView(APIView):

def post(self, request):
# Check if generation is available
if not GenerationUtil.is_enabled():
if not GenerationCore.is_enabled():
raise MsgFailure(
msg="AI generation is not enabled on this instance of Materia"
)
Expand All @@ -87,11 +93,40 @@ def post(self, request):

prompt = request_serializer.validated_data["prompt"]

# Perform generation
result = GenerationUtil.generate_from_prompt(prompt)
# Get the appropriate driver and perform generation
driver = GenerationDriverFactory.get_driver()
result = driver.query_sync(prompt)

return Response(
{
"success": True,
"response": result,
}
)


class GenerateStreamingResponseView(APIView):
    """Stream an AI-generated response for an ongoing conversation.

    POST body is validated by PromptStreamingRequestSerializer
    (``conversation`` required, ``system_prompt`` optional). Returns a
    ``text/event-stream`` StreamingHttpResponse fed by the configured
    generation driver.
    """

    http_method_names = ["post"]
    permission_classes = [CanCreateWidgetInstances]

    def post(self, request):
        # Bail out early when AI generation is disabled for this instance
        if not GenerationCore.is_enabled():
            raise MsgFailure(
                msg="AI generation is not enabled on this instance of Materia"
            )

        request_serializer = PromptStreamingRequestSerializer(data=request.data)
        request_serializer.is_valid(raise_exception=True)

        messages = request_serializer.validated_data["conversation"]
        # system_prompt is declared required=False on the serializer, so it may
        # be absent from validated_data — direct indexing would raise KeyError
        # for any request that omits it. Fall back to an empty prompt instead.
        system_prompt = request_serializer.validated_data.get("system_prompt", "")

        # Get the appropriate driver and stream its output straight to the client
        driver = GenerationDriverFactory.get_driver()
        response = StreamingHttpResponse(
            driver.generate_prompt_stream(messages, system_prompt),
            content_type="text/event-stream",
        )
        # Disable caching and nginx proxy buffering so event-stream chunks
        # reach the client as soon as the driver yields them.
        response["Cache-Control"] = "no-cache"
        response["X-Accel-Buffering"] = "no"
        return response
39 changes: 39 additions & 0 deletions app/core/services/boto_session_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import logging

import boto3
from django.conf import settings

logger = logging.getLogger(__name__)


class BotoSessionService:
    """Builds boto3 Sessions from the ``settings.AWS_SETTINGS`` config."""

    # TODO: move client-related credential configs to another configuration location
    # instead of just s3?

    @staticmethod
    def get_session():
        """Return a boto3.Session configured per AWS_SETTINGS["credential_provider"].

        - ``"imds"``: credentials come from the EC2 instance's IAM role via
          Amazon's IMDSv2 service. HIGHLY recommended for prod usage on AWS.
        - ``"env"``: credentials come from settings — either a named profile,
          or an explicit access key / secret key / session token triple.

        Raises:
            ValueError: if the provider is unset or unrecognized.
        """
        s = settings.AWS_SETTINGS
        if s["credential_provider"] == "imds":
            # Credentials are sourced from the EC2 instance's IAM role;
            # boto3 resolves them automatically from the instance metadata.
            return boto3.Session()

        if s["credential_provider"] == "env":
            session_config = {
                "region_name": s["region"],
            }
            if s["profile"] is not None:
                session_config["profile_name"] = s["profile"]
            else:
                # BUG FIX: the original assigned 1-tuples here — e.g.
                # ``(s["key"],)`` — because of stray trailing commas, so
                # boto3 received tuples instead of credential strings.
                session_config["aws_access_key_id"] = s["key"]
                session_config["aws_secret_access_key"] = s["secret_key"]
                session_config["aws_session_token"] = s["aws_session_token"]
            return boto3.Session(**session_config)

        # ValueError subclasses Exception, so existing broad handlers still match.
        raise ValueError(
            "boto3: Failed to determine credential provider. Did you set the appropriate environment variable?"
        )
Loading
Loading