diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..8574234
--- /dev/null
+++ b/.env.example
@@ -0,0 +1 @@
+COHERE_API_KEY=
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index b04a8c8..9e698c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,5 @@
 
 # rspec failure tracking
 .rspec_status
+
+.env
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2a16cb1..298ac01 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,5 @@
 ## [Unreleased]
+- Migrate to v2 APIs
 
 ## [0.9.11] - 2024-08-01
 - New `rerank()` method
diff --git a/Gemfile b/Gemfile
index 8c80c60..8696af1 100644
--- a/Gemfile
+++ b/Gemfile
@@ -9,3 +9,4 @@ gem "rake", "~> 13.0"
 gem "rspec", "~> 3.0"
 
 gem "standard", "~> 1.28.0"
+gem "dotenv", "~> 2.8.1"
diff --git a/Gemfile.lock b/Gemfile.lock
index cc15050..e402c8b 100644
--- a/Gemfile.lock
+++ b/Gemfile.lock
@@ -9,6 +9,7 @@ GEM
   specs:
     ast (2.4.2)
     diff-lcs (1.5.0)
+    dotenv (2.8.1)
     faraday (2.7.10)
       faraday-net_http (>= 2.0, < 3.1)
       ruby2_keywords (>= 0.0.4)
@@ -74,6 +75,7 @@ PLATFORMS
 
 DEPENDENCIES
   cohere-ruby!
+  dotenv (~> 2.8.1)
   rake (~> 13.0)
   rspec (~> 3.0)
   standard (~> 1.28.0)
diff --git a/README.md b/README.md
index a46174e..e8e3a45 100644
--- a/README.md
+++ b/README.md
@@ -8,12 +8,12 @@
 
 Cohere API client for Ruby.
 
-Part of the [Langchain.rb](https://github.com/andreibondarev/langchainrb) stack.
+Part of the [Langchain.rb](https://github.com/patterns-ai-core/langchainrb) stack.
 
-![Tests status](https://github.com/andreibondarev/cohere-ruby/actions/workflows/ci.yml/badge.svg)
+![Tests status](https://github.com/patterns-ai-core/cohere-ruby/actions/workflows/ci.yml/badge.svg)
 [![Gem Version](https://badge.fury.io/rb/cohere-ruby.svg)](https://badge.fury.io/rb/cohere-ruby)
 [![Docs](http://img.shields.io/badge/yard-docs-blue.svg)](http://rubydoc.info/gems/cohere-ruby)
-[![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/andreibondarev/cohere-ruby/blob/main/LICENSE.txt)
+[![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/patterns-ai-core/cohere-ruby/blob/main/LICENSE.txt)
 [![](https://dcbadge.vercel.app/api/server/WDARp7J2n8?compact=true&style=flat)](https://discord.gg/WDARp7J2n8)
 
 ## Installation
@@ -50,14 +50,28 @@ client.generate(
 
 ```ruby
 client.chat(
-  message: "Hey! How are you?"
+  model: "command-r-plus-08-2024",
+  messages: [{role: "user", content: "Hey! How are you?"}]
 )
 ```
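+
+The raw response is a plain `Hash`. As a minimal sketch (assuming the v2 response shape used in this repo's specs), the reply text can be read with:
+
+```ruby
+response = client.chat(
+  model: "command-r-plus-08-2024",
+  messages: [{role: "user", content: "Hey! How are you?"}]
+)
+response.dig("message", "content", 0, "text")
+```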
 
 `chat` supports a streaming option. You can pass a block to the `chat` method and it will yield a new chunk as soon as it is received.
 
 ```ruby
-client.chat(message: "Hey! How are you?", stream: true) do |chunk, overall_received_bytes|
+client.chat(
+  model: "command-r-plus-08-2024",
+  messages: [{role: "user", content: "Hey! How are you?"}]
+) do |chunk, overall_received_bytes|
   puts "Received #{overall_received_bytes} bytes: #{chunk.force_encoding(Encoding::UTF_8)}"
 end
 ```
@@ -68,25 +82,25 @@ end
 
 ```ruby
 tools = [
-    {
-        name: "query_daily_sales_report",
-        description: "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
-        parameter_definitions: {
-            day: {
-                description: "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
-                type: "str",
-                required: true
-            }
-        }
-    }
+  {
+    name: "query_daily_sales_report",
+    description: "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
+    parameter_definitions: {
+      day: {
+        description: "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
+        type: "str",
+        required: true
+      }
+    }
+  }
 ]
 
 message = "Can you provide a sales summary for 29th September 2023, and also give me some details about the products in the 'Electronics' category, for example their prices and stock levels?"
 
 client.chat(
   model: model,
-  message: message,
-  tools: tools,
+  messages: [{role: "user", content: message}],
+  tools: tools
 )
 ```
@@ -94,7 +108,10 @@ client.chat(
 
 ```ruby
 client.embed(
-  texts: ["hello!"]
+  model: "embed-english-v3.0",
+  texts: ["hello", "goodbye"],
+  input_type: "classification",
+  embedding_types: ["float"]
 )
 ```
@@ -110,6 +127,7 @@ docs = [
 ]
 
 client.rerank(
+  model: "rerank-english-v3.0",
   query: "What is the capital of the United States?",
   documents: docs
 )
@@ -137,8 +155,9 @@ inputs = [
 ]
 
 client.classify(
-  examples: examples,
-  inputs: inputs
+  model: "embed-multilingual-v2.0",
+  inputs: inputs,
+  examples: examples
 )
 ```
@@ -146,7 +165,8 @@ client.classify(
 
 ```ruby
 client.tokenize(
-  text: "hello world!"
+  model: "command-r-plus-08-2024",
+  text: "Hello, world!"
 )
 ```
@@ -154,7 +174,8 @@ client.tokenize(
 
 ```ruby
 client.detokenize(
-  tokens: [33555, 1114 , 34]
+  model: "command-r-plus-08-2024",
+  tokens: [33555, 1114, 34]
 )
 ```
diff --git a/bin/console b/bin/console
index f99ac76..8daab68 100755
--- a/bin/console
+++ b/bin/console
@@ -3,13 +3,7 @@
 
 require "bundler/setup"
 require "cohere"
-
-# You can add fixtures and/or initialization code here to make experimenting
-# with your gem easier. You can also use a different console, if you like.
-
-# (If you use this, don't forget to add pry to your Gemfile!)
-# require "pry"
-# Pry.start
+require "dotenv/load"
 
 client = Cohere::Client.new(
   api_key: ENV['COHERE_API_KEY']
diff --git a/lib/cohere/client.rb b/lib/cohere/client.rb
index 523e8c6..d617c4d 100644
--- a/lib/cohere/client.rb
+++ b/lib/cohere/client.rb
@@ -6,62 +6,63 @@ module Cohere
   class Client
     attr_reader :api_key, :connection
 
-    ENDPOINT_URL = "https://api.cohere.ai/v1"
-
    def initialize(api_key:, timeout: nil)
      @api_key = api_key
      @timeout = timeout
    end

+    # Generates a text response to a user message. Streams the response token by token when a block is given.
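+    #
+    # Minimal example (model name is illustrative; see the README for the full parameter list):
+    #
+    #   client.chat(
+    #     model: "command-r-plus-08-2024",
+    #     messages: [{role: "user", content: "Hey! How are you?"}]
+    #   )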
-# require "pry" -# Pry.start +require "dotenv/load" client = Cohere::Client.new( api_key: ENV['COHERE_API_KEY'] diff --git a/lib/cohere/client.rb b/lib/cohere/client.rb index 523e8c6..d617c4d 100644 --- a/lib/cohere/client.rb +++ b/lib/cohere/client.rb @@ -6,62 +6,56 @@ module Cohere class Client attr_reader :api_key, :connection - ENDPOINT_URL = "https://api.cohere.ai/v1" - def initialize(api_key:, timeout: nil) @api_key = api_key @timeout = timeout end + # Generates a text response to a user message and streams it down, token by token def chat( - message: nil, - model: nil, + model:, + messages:, stream: false, - preamble: nil, - preamble_override: nil, - chat_history: [], - conversation_id: nil, - prompt_truncation: nil, - connectors: [], - search_queries_only: false, + tools: [], documents: [], - citation_quality: nil, - temperature: nil, + citation_options: nil, + response_format: nil, + safety_mode: nil, max_tokens: nil, - k: nil, - p: nil, + stop_sequences: nil, + temperature: nil, seed: nil, frequency_penalty: nil, presence_penalty: nil, - tools: [], + k: nil, + p: nil, + logprops: nil, &block ) - response = connection.post("chat") do |req| + response = v2_connection.post("chat") do |req| req.body = {} - req.body[:message] = message if message - req.body[:model] = model if model - if stream || block - req.body[:stream] = true - req.options.on_data = block if block - end - req.body[:preamble] = preamble if preamble - req.body[:preamble_override] = preamble_override if preamble_override - req.body[:chat_history] = chat_history if chat_history - req.body[:conversation_id] = conversation_id if conversation_id - req.body[:prompt_truncation] = prompt_truncation if prompt_truncation - req.body[:connectors] = connectors if connectors - req.body[:search_queries_only] = search_queries_only if search_queries_only - req.body[:documents] = documents if documents - req.body[:citation_quality] = citation_quality if citation_quality - req.body[:temperature] = temperature if temperature + req.body[:model] = model + req.body[:messages] = messages if messages + req.body[:tools] = tools if tools.any? + req.body[:documents] = documents if documents.any? + req.body[:citation_options] = citation_options if citation_options + req.body[:response_format] = response_format if response_format + req.body[:safety_mode] = safety_mode if safety_mode req.body[:max_tokens] = max_tokens if max_tokens - req.body[:k] = k if k - req.body[:p] = p if p + req.body[:stop_sequences] = stop_sequences if stop_sequences + req.body[:temperature] = temperature if temperature req.body[:seed] = seed if seed req.body[:frequency_penalty] = frequency_penalty if frequency_penalty req.body[:presence_penalty] = presence_penalty if presence_penalty - req.body[:tools] = tools if tools + req.body[:k] = k if k + req.body[:p] = p if p + req.body[:logprops] = logprops if logprops + + if stream || block + req.body[:stream] = true + req.options.on_data = block if block + end end response.body end @@ -84,7 +78,7 @@ def generate( logit_bias: nil, truncate: nil ) - response = connection.post("generate") do |req| + response = v1_connection.post("generate") do |req| req.body = {prompt: prompt} req.body[:model] = model if model req.body[:num_generations] = num_generations if num_generations @@ -104,36 +98,44 @@ def generate( response.body end + # This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents. 
     def embed(
-      texts:,
-      model: nil,
-      input_type: nil,
+      model:,
+      input_type:,
+      embedding_types:,
+      texts: nil,
+      images: nil,
       truncate: nil
     )
-      response = connection.post("embed") do |req|
-        req.body = {texts: texts}
-        req.body[:model] = model if model
-        req.body[:input_type] = input_type if input_type
+      response = v2_connection.post("embed") do |req|
+        req.body = {
+          model: model,
+          input_type: input_type,
+          embedding_types: embedding_types
+        }
+        req.body[:texts] = texts if texts
+        req.body[:images] = images if images
         req.body[:truncate] = truncate if truncate
       end
       response.body
     end
 
+    # This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
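+    #
+    # Minimal example (model and documents are illustrative; v2 rerank accepts an array of strings):
+    #
+    #   client.rerank(
+    #     model: "rerank-english-v3.0",
+    #     query: "What is the capital of the United States?",
+    #     documents: ["Ottawa is the capital of Canada.", "Washington, D.C. is the capital of the United States."]
+    #   )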
     def rerank(
+      model:,
       query:,
       documents:,
-      model: nil,
       top_n: nil,
       rank_fields: nil,
       return_documents: nil,
       max_chunks_per_doc: nil
     )
-      response = connection.post("rerank") do |req|
+      response = v2_connection.post("rerank") do |req|
         req.body = {
+          model: model,
           query: query,
           documents: documents
         }
-        req.body[:model] = model if model
         req.body[:top_n] = top_n if top_n
         req.body[:rank_fields] = rank_fields if rank_fields
         req.body[:return_documents] = return_documents if return_documents
@@ -142,41 +168,54 @@ def rerank(
       response.body
     end
 
+    # This endpoint makes a prediction about which label fits the specified text inputs best.
     def classify(
+      model:,
       inputs:,
-      examples:,
-      model: nil,
-      present: nil,
+      examples: nil,
+      preset: nil,
       truncate: nil
     )
-      response = connection.post("classify") do |req|
+      response = v1_connection.post("classify") do |req|
         req.body = {
-          inputs: inputs,
-          examples: examples
+          model: model,
+          inputs: inputs
         }
-        req.body[:model] = model if model
-        req.body[:present] = present if present
+        req.body[:examples] = examples if examples
+        req.body[:preset] = preset if preset
         req.body[:truncate] = truncate if truncate
       end
       response.body
     end
 
-    def tokenize(text:, model: nil)
-      response = connection.post("tokenize") do |req|
-        req.body = model.nil? ? {text: text} : {text: text, model: model}
+    # This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE).
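+    #
+    # Minimal example (model name is illustrative):
+    #
+    #   client.tokenize(model: "command-r-plus-08-2024", text: "Hello, world!")
+    #   # => {"tokens" => [33555, 1114, 34], ...}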
+    def tokenize(text:, model:)
+      response = v1_connection.post("tokenize") do |req|
+        req.body = {text: text, model: model}
       end
       response.body
     end
 
-    def detokenize(tokens:, model: nil)
-      response = connection.post("detokenize") do |req|
-        req.body = model.nil? ? {tokens: tokens} : {tokens: tokens, model: model}
+    # This endpoint takes tokens using byte-pair encoding and returns their text representation.
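+    #
+    # Minimal example (the inverse of tokenize above; model name is illustrative):
+    #
+    #   client.detokenize(model: "command-r-plus-08-2024", tokens: [33555, 1114, 34])
+    #   # => {"text" => "hello world!", ...}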
).dig("tokens")).to eq([33555, 1114, 34]) end end - describe "#tokenize_with_model" do - let(:tokenize_result) { JSON.parse(File.read("spec/fixtures/tokenize_result.json")) } - let(:response) { OpenStruct.new(body: tokenize_result) } - - before do - allow_any_instance_of(Faraday::Connection).to receive(:post) - .with("tokenize") - .and_return(response) - end - - it "returns a response" do - expect(subject.tokenize( - text: "Hello, world!", - model: "base" - ).dig("tokens")).to eq([33555, 1114, 34]) - end - end - describe "#detokenize" do - let(:detokenize_result) { JSON.parse(File.read("spec/fixtures/detokenize_result.json")) } + let(:detokenize_result) { JSON.parse(File.read("spec/fixtures/detokenize.json")) } let(:response) { OpenStruct.new(body: detokenize_result) } before do @@ -150,31 +157,14 @@ it "returns a response" do expect(subject.detokenize( + model: "command-r-plus-08-2024", tokens: [33555, 1114, 34] ).dig("text")).to eq("hello world!") end end - describe "#detokenize_with_model" do - let(:detokenize_result) { JSON.parse(File.read("spec/fixtures/detokenize_result.json")) } - let(:response) { OpenStruct.new(body: detokenize_result) } - - before do - allow_any_instance_of(Faraday::Connection).to receive(:post) - .with("detokenize") - .and_return(response) - end - - it "returns a response" do - expect(subject.detokenize( - tokens: [33555, 1114, 34], - model: "base" - ).dig("text")).to eq("hello world!") - end - end - describe "#detect_language" do - let(:detect_language_result) { JSON.parse(File.read("spec/fixtures/detect_language_result.json")) } + let(:detect_language_result) { JSON.parse(File.read("spec/fixtures/detect_language.json")) } let(:response) { OpenStruct.new(body: detect_language_result) } before do @@ -191,7 +181,7 @@ end describe "#summarize" do - let(:summarize_result) { JSON.parse(File.read("spec/fixtures/summarize_result.json")) } + let(:summarize_result) { JSON.parse(File.read("spec/fixtures/summarize.json")) } let(:response) { OpenStruct.new(body: summarize_result) } before do diff --git a/spec/fixtures/chat.json b/spec/fixtures/chat.json new file mode 100644 index 0000000..1974c7e --- /dev/null +++ b/spec/fixtures/chat.json @@ -0,0 +1,23 @@ +{ + "id": "731bae88-f610-45a4-8f75-2b44a7856388", + "message": { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "I'm doing well, thank you for asking! As an AI language model, I don't have emotions or feelings, but I'm designed to provide helpful and informative responses to assist you in the best way I can. Is there anything you'd like to know or discuss today?" 
+      }
+    ]
+  },
+  "finish_reason": "COMPLETE",
+  "usage": {
+    "billed_units": {
+      "input_tokens": 6,
+      "output_tokens": 56
+    },
+    "tokens": {
+      "input_tokens": 207,
+      "output_tokens": 56
+    }
+  }
+}
diff --git a/spec/fixtures/classify.json b/spec/fixtures/classify.json
new file mode 100644
index 0000000..fd7996c
--- /dev/null
+++ b/spec/fixtures/classify.json
@@ -0,0 +1,35 @@
+{
+  "id": "dd6484b4-952a-45d1-8e2a-1fbcadf1a3bc",
+  "classifications": [
+    {
+      "classification_type": "single-label",
+      "confidence": 0.66766936,
+      "confidences": [0.66766936],
+      "id": "5abf9fea-dcb1-42cb-8bc7-e797c80f2293",
+      "input": "Confirm your email address",
+      "labels": {
+        "Not spam": { "confidence": 0.66766936 },
+        "Spam": { "confidence": 0.33233067 }
+      },
+      "prediction": "Not spam",
+      "predictions": ["Not spam"]
+    },
+    {
+      "classification_type": "single-label",
+      "confidence": 0.5345887,
+      "confidences": [0.5345887],
+      "id": "75ddba7e-466d-408a-aca2-78165e6f1dfa",
+      "input": "hey i need u to send some $",
+      "labels": {
+        "Not spam": { "confidence": 0.46541128 },
+        "Spam": { "confidence": 0.5345887 }
+      },
+      "prediction": "Spam",
+      "predictions": ["Spam"]
+    }
+  ],
+  "meta": {
+    "api_version": { "version": "2" },
+    "billed_units": { "classifications": 2 }
+  }
+}
diff --git a/spec/fixtures/classify_result.json b/spec/fixtures/classify_result.json
deleted file mode 100644
index 5fd6576..0000000
--- a/spec/fixtures/classify_result.json
+++ /dev/null
@@ -1,38 +0,0 @@
-{
-  "id": "7b8be981-cb89-4d16-9a2c-dec3a2b0f71d",
-  "classifications": [
-    {
-      "id": "c33e7cf7-1d5f-41f8-916b-20107b73fca8",
-      "input": "Confirm your email address",
-      "prediction": "Not spam",
-      "confidence": 0.80833024,
-      "labels": {
-        "Not spam": {
-          "confidence": 0.80833024
-        },
-        "Spam": {
-          "confidence": 0.19166975
-        }
-      }
-    },
-    {
-      "id": "18384c67-9cab-4960-9fc3-ca577586701b",
-      "input": "hey i need u to send some $",
-      "prediction": "Spam",
-      "confidence": 0.9893047,
-      "labels": {
-        "Not spam": {
-          "confidence": 0.010695281
-        },
-        "Spam": {
-          "confidence": 0.9893047
-        }
-      }
-    }
-  ],
-  "meta": {
-    "api_version": {
-      "version": "1"
-    }
-  }
-}
\ No newline at end of file
diff --git a/spec/fixtures/detect_language_result.json b/spec/fixtures/detect_language.json
similarity index 100%
rename from spec/fixtures/detect_language_result.json
rename to spec/fixtures/detect_language.json
diff --git a/spec/fixtures/detokenize_result.json b/spec/fixtures/detokenize.json
similarity index 100%
rename from spec/fixtures/detokenize_result.json
rename to spec/fixtures/detokenize.json
diff --git a/spec/fixtures/embed.json b/spec/fixtures/embed.json
new file mode 100644
index 0000000..1b4698e
--- /dev/null
+++ b/spec/fixtures/embed.json
@@ -0,0 +1,22 @@
+{
+  "id": "420351e8-81e6-490d-8ca1-f320911c5fde",
+  "texts": ["hello", "goodbye"],
+  "embeddings": {
+    "float": [
+      [
+        0.017578125,
+        -0.009162903,
+        -0.046325684
+      ]
+    ]
+  },
+  "meta": {
+    "api_version": {
+      "version": "2"
+    },
+    "billed_units": {
+      "input_tokens": 2
+    }
+  },
+  "response_type": "embeddings_by_type"
+}
\ No newline at end of file
diff --git a/spec/fixtures/embed_result.json b/spec/fixtures/embed_result.json
deleted file mode 100644
index 5e92de0..0000000
--- a/spec/fixtures/embed_result.json
+++ /dev/null
@@ -1,10 +0,0 @@
-{
-  "id": "83b78a65-8ede-44ed-a8dc-6afb3ad4a78b",
-  "texts": ["hello!"],
-  "embeddings": [[ 1.2177734, 0.67529297, 2.0742188 ]],
-  "meta": {
-    "api_version": {
-      "version": "1"
-    }
-  }
-}
diff --git a/spec/fixtures/generate_result.json b/spec/fixtures/generate.json
similarity index 100%
rename from spec/fixtures/generate_result.json
rename to spec/fixtures/generate.json
diff --git a/spec/fixtures/summarize_result.json b/spec/fixtures/summarize.json
similarity index 100%
rename from spec/fixtures/summarize_result.json
rename to spec/fixtures/summarize.json
diff --git a/spec/fixtures/tokenize_result.json b/spec/fixtures/tokenize.json
similarity index 100%
rename from spec/fixtures/tokenize_result.json
rename to spec/fixtures/tokenize.json